achase25 committed on
Commit
fc12805
·
verified ·
1 Parent(s): 8549faf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -142
app.py CHANGED
@@ -1,22 +1,22 @@
1
- # app.py Multimodal router: one image input + freeform command -> text OR image output
2
- # Commands (examples):
3
- # "describe the photo" -> text caption
4
- # "write a story about the image" -> text story
5
- # "make the photo look like a cartoon" -> image stylization
6
  #
7
- # Dependencies / requirements.txt:
8
- # pip install -q gradio transformers diffusers accelerate torch safetensors pillow
 
 
 
 
 
9
 
10
  import os
11
- import re
12
- import random
13
- from typing import Optional, Tuple
14
 
15
  import torch
16
  import gradio as gr
17
  from PIL import Image
18
 
19
- # ---- Transformers: caption + story ----
20
  from transformers import (
21
  VisionEncoderDecoderModel,
22
  AutoImageProcessor,
@@ -24,38 +24,22 @@ from transformers import (
24
  pipeline as hf_pipeline,
25
  )
26
 
27
- # ---- Diffusers: image-to-image stylization ----
28
- from diffusers import StableDiffusionImg2ImgPipeline
29
-
30
- # ------------- Config -------------
31
- CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "nlpconnect/vit-gpt2-image-captioning")
32
- STORY_MODEL_ID = os.getenv("STORY_MODEL_ID", "google/flan-t5-large") # light-ish; ok stories
33
- IMG2IMG_MODEL_ID = os.getenv("IMG2IMG_MODEL_ID", "stabilityai/stable-diffusion-2-1")
34
-
35
- MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "768")) # clamp big uploads to save VRAM
36
- DEFAULT_STEPS = int(os.getenv("STEPS", "30"))
37
- DEFAULT_GUIDANCE = float(os.getenv("GUIDANCE", "7.5"))
38
- DEFAULT_STRENGTH = float(os.getenv("STRENGTH", "0.6")) # 0..1 (higher = more stylized, less like original)
39
 
40
- DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
41
- DTYPE = torch.float16 if (DEVICE == "cuda") else torch.float32
 
42
 
43
- HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
44
 
45
- # ---- Space/runtime feature flags ----
46
- CARTOON_AVAILABLE = torch.cuda.is_available() # SD img2img is GPU-only on Spaces (CPU will timeout)
47
-
48
- # CPU-friendly fallbacks (keep things snappy on Spaces CPU)
49
- if not CARTOON_AVAILABLE:
50
- DEFAULT_STEPS = min(DEFAULT_STEPS, 20)
51
- DEFAULT_GUIDANCE = min(DEFAULT_GUIDANCE, 7.5)
52
-
53
- # ------------- Caches -------------
54
  _caption_bundle = {}
55
  _story_pipe = None
56
- _img2img_pipe = None
57
 
58
- # ------------- Utils -------------
59
  def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
60
  w, h = img.size
61
  if max(w, h) <= max_side:
@@ -66,72 +50,56 @@ def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
66
  else:
67
  new_h = max_side
68
  new_w = int(w * (max_side / h))
69
- return img.resize((new_w // 8 * 8, new_h // 8 * 8), Image.LANCZOS) # multiples of 8 for SD
70
-
71
- def _seeded_generator(seed: Optional[int]):
72
- if seed is None or str(seed).strip() == "":
73
- return None
74
- try:
75
- seed = int(seed)
76
- except Exception:
77
- return None
78
- dev = "cuda" if DEVICE == "cuda" else "cpu"
79
- return torch.Generator(device=dev).manual_seed(seed)
80
 
81
- # ------------- Loaders -------------
82
  def get_caption_bundle():
 
83
  global _caption_bundle
84
  if _caption_bundle:
85
  return _caption_bundle
 
86
  processor = AutoImageProcessor.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
 
87
  tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID, use_fast=True, token=HF_TOKEN)
88
  model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
89
- # GPT2 has no pad by default -> set pad=eos to avoid mask issues
 
90
  if tokenizer.pad_token is None:
91
  tokenizer.pad_token = tokenizer.eos_token
92
  model.config.pad_token_id = tokenizer.pad_token_id
93
  model.config.eos_token_id = tokenizer.eos_token_id
94
  if getattr(model.config, "decoder_start_token_id", None) is None and tokenizer.bos_token_id is not None:
95
  model.config.decoder_start_token_id = tokenizer.bos_token_id
 
96
  model.to(DEVICE).eval()
 
97
  _caption_bundle = {"processor": processor, "tokenizer": tokenizer, "model": model}
98
  return _caption_bundle
99
 
100
  def get_story_pipe():
 
101
  global _story_pipe
102
  if _story_pipe is not None:
103
  return _story_pipe
104
- # Flan-T5 works with text2text-generation
105
- _story_pipe = hf_pipeline("text2text-generation", model=STORY_MODEL_ID, device_map="auto", model_kwargs={"torch_dtype": DTYPE})
106
- return _story_pipe
107
-
108
- def get_img2img_pipe():
109
- global _img2img_pipe
110
- if _img2img_pipe is not None:
111
- return _img2img_pipe
112
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
113
- IMG2IMG_MODEL_ID,
114
- torch_dtype=DTYPE,
115
- safety_checker=None, # flip to enable safety if you prefer
116
- requires_safety_checker=False,
117
- use_safetensors=True,
118
  )
119
- pipe = pipe.to(DEVICE)
120
- try:
121
- pipe.enable_xformers_memory_efficient_attention()
122
- except Exception:
123
- pass
124
- _img2img_pipe = pipe
125
- return _img2img_pipe
126
 
127
- # ------------- Ops -------------
128
  @torch.inference_mode()
129
  def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4) -> str:
130
- bundle = get_caption_bundle()
131
- proc, tok, mdl = bundle["processor"], bundle["tokenizer"], bundle["model"]
132
- pv = proc(image.convert("RGB"), return_tensors="pt").pixel_values.to(DEVICE)
 
 
133
  out = mdl.generate(
134
- pixel_values=pv,
135
  max_new_tokens=max_new_tokens,
136
  num_beams=num_beams,
137
  pad_token_id=tok.pad_token_id,
@@ -142,13 +110,13 @@ def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4)
142
  def op_story(
143
  image: Image.Image,
144
  num_sentences: int = 5,
145
- max_new_tokens: int = 220, # allow enough room
146
- min_new_tokens: int = 80, # force >= ~80 tokens (~5 sentences)
147
  temperature: float = 0.9,
148
  top_p: float = 0.92,
149
  no_repeat_ngram_size: int = 3,
150
  ) -> str:
151
- # Ground the story with a caption of the image
152
  caption = op_caption(image)
153
 
154
  prompt = (
@@ -157,20 +125,20 @@ def op_story(
157
  f"Image description: {caption}\n\nStory:"
158
  )
159
 
160
- story_pipe = get_story_pipe()
161
- out = story_pipe(
162
  prompt,
163
  do_sample=True,
164
  temperature=temperature,
165
  top_p=top_p,
166
- min_new_tokens=min_new_tokens, # <- prevents early stop
167
  max_new_tokens=max_new_tokens,
168
  no_repeat_ngram_size=no_repeat_ngram_size,
169
  num_return_sequences=1,
170
  )
171
  text = out[0]["generated_text"].strip()
172
 
173
- # Safety belt: hard-trim to exactly N sentences
174
  import re
175
  sents = re.split(r'(?<=[.!?])\s+', text)
176
  sents = [s.strip() for s in sents if s.strip()]
@@ -178,80 +146,34 @@ def op_story(
178
  text = " ".join(sents[:num_sentences])
179
  return text
180
 
181
-
182
-
183
- @torch.inference_mode()
184
- def op_cartoon(image: Image.Image, steps=DEFAULT_STEPS, guidance=DEFAULT_GUIDANCE, strength=DEFAULT_STRENGTH, seed: Optional[int]=None):
185
- img = _resize_max(image.convert("RGB"))
186
- gen = _seeded_generator(seed)
187
- pipe = get_img2img_pipe()
188
- prompt = "cartoon, cel-shaded, flat colors, bold outlines, clean lineart, anime style, comic book"
189
- negative = "photorealistic, blurry, noisy, artifacts, distorted, watermark"
190
- result = pipe(
191
- prompt=prompt,
192
- negative_prompt=negative,
193
- image=img,
194
- strength=float(strength),
195
- guidance_scale=float(guidance),
196
- num_inference_steps=int(steps),
197
- generator=gen,
198
- )
199
- return result.images[0]
200
-
201
- # ------------- Router -------------
202
- def route_command(command: str) -> str:
203
- c = (command or "").lower()
204
- if any(k in c for k in ["cartoon", "sketch", "comic", "anime", "illustration"]):
205
- return "cartoon"
206
- if any(k in c for k in ["story", "poem", "narrative", "write"]):
207
- return "story"
208
- # default / describe / caption / explain
209
- return "caption"
210
-
211
- # ------------- Gradio App -------------
212
- def run(image: Image.Image, command: str, steps: int, guidance: float, strength: float, seed: str):
213
  if image is None:
214
- raise gr.Error("Upload an image.")
215
- mode = route_command(command)
216
-
217
- if mode == "cartoon":
218
- if not CARTOON_AVAILABLE:
219
- raise gr.Error("Cartoon mode requires a GPU and is disabled on this Space’s hardware.")
220
- img = op_cartoon(
221
- image,
222
- steps=steps,
223
- guidance=guidance,
224
- strength=strength,
225
- seed=int(seed) if seed else None,
226
- )
227
- return None, img, f"Mode: cartoon ({steps} steps, guidance {guidance}, strength {strength}, seed {seed or 'None'})"
228
- elif mode == "story":
229
  txt = op_story(image)
230
  return txt, None, "Mode: story"
231
  else:
232
  txt = op_caption(image)
233
  return txt, None, "Mode: caption"
234
 
235
-
236
  with gr.Blocks(css="footer {visibility:hidden}") as demo:
237
- gr.Markdown("# Image Command Router describe • cartoonize • write a story")
238
  with gr.Row():
239
  with gr.Column():
240
  inp_img = gr.Image(type="pil", label="Image")
241
- inp_cmd = gr.Textbox(label="Command", placeholder='e.g., "describe the photo", "make the photo look like a cartoon", "write a story about the image"', lines=2, value="describe the photo")
242
- with gr.Accordion("Advanced (cartoon mode)", open=False, visible=CARTOON_AVAILABLE):
243
- steps = gr.Slider(1, 75, value=DEFAULT_STEPS, step=1, label="Steps")
244
- guidance = gr.Slider(0.0, 15.0, value=DEFAULT_GUIDANCE, step=0.1, label="Guidance (CFG)")
245
- strength = gr.Slider(0.1, 1.0, value=DEFAULT_STRENGTH, step=0.05, label="Strength (how much to change)")
246
- seed = gr.Textbox(value="", label="Seed (optional int)")
247
  go = gr.Button("Run", variant="primary")
248
  with gr.Column():
249
  out_text = gr.Textbox(label="Text output", lines=10)
250
- out_image = gr.Image(label="Image output")
251
  status = gr.Markdown()
252
- go.click(run, inputs=[inp_img, inp_cmd, steps, guidance, strength, seed], outputs=[out_text, out_image, status], scroll_to_output=True)
253
 
254
  if __name__ == "__main__":
255
- # queue() helps Spaces handle concurrent requests + long inference safely
256
  demo.queue(max_size=8).launch()
257
-
 
1
+ # CPU-only Hugging Face Space: Image -> (Caption OR Story)
2
+ # - Caption: Salesforce BLIP or ViT-GPT2 (set via env or leave defaults)
3
+ # - Story: text2text generation using a lightweight T5-family model
 
 
4
  #
5
+ # Env (optional):
6
+ # CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-large" (heavier, better)
7
+ # or "nlpconnect/vit-gpt2-image-captioning" (lighter, faster on CPU)
8
+ # STORY_MODEL_ID = "google/flan-t5-large" (default, decent on CPU)
9
+ # HUGGINGFACE_HUB_TOKEN / HF_TOKEN (if models require auth)
10
+ #
11
+ # Requirements are in requirements.txt.
12
 
13
  import os
14
+ from typing import Optional
 
 
15
 
16
  import torch
17
  import gradio as gr
18
  from PIL import Image
19
 
 
20
  from transformers import (
21
  VisionEncoderDecoderModel,
22
  AutoImageProcessor,
 
24
  pipeline as hf_pipeline,
25
  )
26
 
27
+ # -------------------- Config --------------------
28
+ CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "Salesforce/blip-image-captioning-large")
29
+ STORY_MODEL_ID = os.getenv("STORY_MODEL_ID", "google/flan-t5-large")
30
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
 
 
 
 
 
 
 
 
31
 
32
+ # CPU only (works on Spaces without GPU)
33
+ DEVICE = "cpu"
34
+ DTYPE = torch.float32
35
 
36
+ MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "768")) # clamp inputs to keep it snappy
37
 
38
+ # -------------------- Caches --------------------
 
 
 
 
 
 
 
 
39
  _caption_bundle = {}
40
  _story_pipe = None
 
41
 
42
+ # -------------------- Utils --------------------
43
  def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
44
  w, h = img.size
45
  if max(w, h) <= max_side:
 
50
  else:
51
  new_h = max_side
52
  new_w = int(w * (max_side / h))
53
+ return img.resize((new_w, new_h), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # -------------------- Loaders --------------------
56
  def get_caption_bundle():
57
+ """Load a vision->text captioning model (BLIP or ViT-GPT2 family) with sane tokenizer settings."""
58
  global _caption_bundle
59
  if _caption_bundle:
60
  return _caption_bundle
61
+
62
  processor = AutoImageProcessor.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
63
+ # Use fast tokenizer when available to silence 'use_fast' warnings
64
  tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID, use_fast=True, token=HF_TOKEN)
65
  model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
66
+
67
+ # GPT2 lacks pad by default; set to eos and mirror in config to avoid attention_mask warnings
68
  if tokenizer.pad_token is None:
69
  tokenizer.pad_token = tokenizer.eos_token
70
  model.config.pad_token_id = tokenizer.pad_token_id
71
  model.config.eos_token_id = tokenizer.eos_token_id
72
  if getattr(model.config, "decoder_start_token_id", None) is None and tokenizer.bos_token_id is not None:
73
  model.config.decoder_start_token_id = tokenizer.bos_token_id
74
+
75
  model.to(DEVICE).eval()
76
+
77
  _caption_bundle = {"processor": processor, "tokenizer": tokenizer, "model": model}
78
  return _caption_bundle
79
 
80
  def get_story_pipe():
81
+ """Lightweight text2text pipeline for story generation."""
82
  global _story_pipe
83
  if _story_pipe is not None:
84
  return _story_pipe
85
+ _story_pipe = hf_pipeline(
86
+ "text2text-generation",
87
+ model=STORY_MODEL_ID,
88
+ device=-1, # CPU
89
+ model_kwargs={"torch_dtype": DTYPE},
 
 
 
 
 
 
 
 
 
90
  )
91
+ return _story_pipe
 
 
 
 
 
 
92
 
93
+ # -------------------- Ops --------------------
94
  @torch.inference_mode()
95
  def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4) -> str:
96
+ b = get_caption_bundle()
97
+ proc, tok, mdl = b["processor"], b["tokenizer"], b["model"]
98
+ image = _resize_max(image.convert("RGB"))
99
+ pixel_values = proc(image, return_tensors="pt").pixel_values.to(DEVICE)
100
+
101
  out = mdl.generate(
102
+ pixel_values=pixel_values,
103
  max_new_tokens=max_new_tokens,
104
  num_beams=num_beams,
105
  pad_token_id=tok.pad_token_id,
 
110
  def op_story(
111
  image: Image.Image,
112
  num_sentences: int = 5,
113
+ max_new_tokens: int = 220,
114
+ min_new_tokens: int = 80,
115
  temperature: float = 0.9,
116
  top_p: float = 0.92,
117
  no_repeat_ngram_size: int = 3,
118
  ) -> str:
119
+ # Ground with a caption first
120
  caption = op_caption(image)
121
 
122
  prompt = (
 
125
  f"Image description: {caption}\n\nStory:"
126
  )
127
 
128
+ pipe = get_story_pipe()
129
+ out = pipe(
130
  prompt,
131
  do_sample=True,
132
  temperature=temperature,
133
  top_p=top_p,
134
+ min_new_tokens=min_new_tokens, # prevents early stop at 1 sentence
135
  max_new_tokens=max_new_tokens,
136
  no_repeat_ngram_size=no_repeat_ngram_size,
137
  num_return_sequences=1,
138
  )
139
  text = out[0]["generated_text"].strip()
140
 
141
+ # Trim to exactly N sentences (safety belt)
142
  import re
143
  sents = re.split(r'(?<=[.!?])\s+', text)
144
  sents = [s.strip() for s in sents if s.strip()]
 
146
  text = " ".join(sents[:num_sentences])
147
  return text
148
 
149
+ # -------------------- Gradio UI --------------------
150
+ def run(image: Image.Image, mode: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  if image is None:
152
+ raise gr.Error("Upload an image first.")
153
+ mode = (mode or "Caption").lower()
154
+ if mode == "story":
 
 
 
 
 
 
 
 
 
 
 
 
155
  txt = op_story(image)
156
  return txt, None, "Mode: story"
157
  else:
158
  txt = op_caption(image)
159
  return txt, None, "Mode: caption"
160
 
 
161
  with gr.Blocks(css="footer {visibility:hidden}") as demo:
162
+ gr.Markdown("# Image Caption or Story (CPU-only)")
163
  with gr.Row():
164
  with gr.Column():
165
  inp_img = gr.Image(type="pil", label="Image")
166
+ mode = gr.Radio(
167
+ choices=["Caption", "Story"],
168
+ value="Caption",
169
+ label="Task",
170
+ )
 
171
  go = gr.Button("Run", variant="primary")
172
  with gr.Column():
173
  out_text = gr.Textbox(label="Text output", lines=10)
174
+ out_image = gr.Image(label="(unused for CPU app)", visible=False)
175
  status = gr.Markdown()
176
+ go.click(run, inputs=[inp_img, mode], outputs=[out_text, out_image, status], scroll_to_output=True)
177
 
178
  if __name__ == "__main__":
 
179
  demo.queue(max_size=8).launch()