achase25 committed on
Commit e6ef52a · verified · 1 Parent(s): 9780798

Update app.py

Files changed (1)
  1. app.py +43 -59
app.py CHANGED
@@ -1,14 +1,22 @@
-# app.py — Image+Command router: "describe photo" (caption), "write a story" (text), "make it a cartoon" (img2img)
-# Deps:
+# app.py — Multimodal router: one image input + freeform command -> text OR image output
+# Commands (examples):
+#   "describe the photo"                 -> text caption
+#   "write a story about the image"      -> text story
+#   "make the photo look like a cartoon" -> image stylization
+#
+# Dependencies / requirements.txt:
 # pip install -q gradio transformers diffusers accelerate torch safetensors pillow
+
 import os
 import re
-from typing import Optional
+import random
+from typing import Optional, Tuple

 import torch
 import gradio as gr
 from PIL import Image

+# ---- Transformers: caption + story ----
 from transformers import (
     VisionEncoderDecoderModel,
     AutoImageProcessor,
@@ -16,30 +24,30 @@ from transformers import (
     pipeline as hf_pipeline,
 )

+# ---- Diffusers: image-to-image stylization ----
 from diffusers import StableDiffusionImg2ImgPipeline

-# ----------------- Config -----------------
+# ------------- Config -------------
 CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "nlpconnect/vit-gpt2-image-captioning")
-# For longer/better stories you can set: google/flan-t5-xl (needs ~10–12GB VRAM) or google/flan-ul2 (heavy)
-STORY_MODEL_ID = os.getenv("STORY_MODEL_ID", "google/flan-t5-large")
+STORY_MODEL_ID = os.getenv("STORY_MODEL_ID", "google/flan-t5-large")  # light-ish; ok stories
 IMG2IMG_MODEL_ID = os.getenv("IMG2IMG_MODEL_ID", "stabilityai/stable-diffusion-2-1")

-MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "768"))
+MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "768"))  # clamp big uploads to save VRAM
 DEFAULT_STEPS = int(os.getenv("STEPS", "30"))
 DEFAULT_GUIDANCE = float(os.getenv("GUIDANCE", "7.5"))
-DEFAULT_STRENGTH = float(os.getenv("STRENGTH", "0.6"))
+DEFAULT_STRENGTH = float(os.getenv("STRENGTH", "0.6"))  # 0..1 (higher = more stylized, less like original)

 DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
-DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+DTYPE = torch.float16 if (DEVICE == "cuda") else torch.float32

 HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None

-# ----------------- Caches -----------------
+# ------------- Caches -------------
 _caption_bundle = {}
 _story_pipe = None
 _img2img_pipe = None

-# ----------------- Utils -----------------
+# ------------- Utils -------------
 def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
     w, h = img.size
     if max(w, h) <= max_side:
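For reference, here is how the DEVICE/DTYPE selection above resolves when run on its own; a minimal standalone sketch (no app state assumed):

```python
import torch

# Mirrors the app's selection: prefer CUDA, then Apple MPS, else CPU.
# fp16 is only used on CUDA; MPS and CPU stay in fp32.
device = "cuda" if torch.cuda.is_available() else (
    "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu"
)
dtype = torch.float16 if device == "cuda" else torch.float32
print(device, dtype)  # e.g. "cuda torch.float16" on a GPU box
```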
@@ -50,8 +58,7 @@ def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
     else:
         new_h = max_side
         new_w = int(w * (max_side / h))
-    # Snap to multiples of 8 for SD pipelines
-    return img.resize((new_w // 8 * 8, new_h // 8 * 8), Image.LANCZOS)
+    return img.resize((new_w // 8 * 8, new_h // 8 * 8), Image.LANCZOS)  # multiples of 8 for SD

 def _seeded_generator(seed: Optional[int]):
     if seed is None or str(seed).strip() == "":
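The resize helper clamps the long side to MAX_IMG_SIDE, then snaps both dimensions down to multiples of 8, which Stable Diffusion's VAE expects. A standalone sketch of the same logic (the early-return branch and the 1000×700 input are assumptions for illustration):

```python
from PIL import Image

def resize_max(img: Image.Image, max_side: int = 768) -> Image.Image:
    w, h = img.size
    if max(w, h) <= max_side:
        return img  # small images pass through untouched (assumed branch)
    if w >= h:
        new_w, new_h = max_side, int(h * (max_side / w))
    else:
        new_h, new_w = max_side, int(w * (max_side / h))
    return img.resize((new_w // 8 * 8, new_h // 8 * 8), Image.LANCZOS)

# 1000x700 -> long side clamped to 768, short side int(700*0.768)=537, snapped to 536
print(resize_max(Image.new("RGB", (1000, 700))).size)  # (768, 536)
```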
@@ -63,35 +70,21 @@ def _seeded_generator(seed: Optional[int]):
     dev = "cuda" if DEVICE == "cuda" else "cpu"
     return torch.Generator(device=dev).manual_seed(seed)

-def parse_num_sentences(cmd: str, default: int = 5) -> int:
-    m = re.search(r"(\d+)\s*(?:sentences?|sentence)", (cmd or "").lower())
-    if m:
-        try:
-            n = int(m.group(1))
-            return max(1, min(n, 20))  # keep sane bounds
-        except Exception:
-            pass
-    return default
-
-# ----------------- Loaders -----------------
+# ------------- Loaders -------------
 def get_caption_bundle():
     global _caption_bundle
     if _caption_bundle:
         return _caption_bundle
-    # use_fast=True avoids “slow processor/tokenizer” warnings
     processor = AutoImageProcessor.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
     tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID, use_fast=True, token=HF_TOKEN)
     model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)
-
-    # GPT-2 decoders have no pad by default -> set pad=eos; set ids so generate() is happy
+    # GPT2 has no pad by default -> set pad=eos to avoid mask issues
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "right"
     model.config.pad_token_id = tokenizer.pad_token_id
     model.config.eos_token_id = tokenizer.eos_token_id
     if getattr(model.config, "decoder_start_token_id", None) is None and tokenizer.bos_token_id is not None:
         model.config.decoder_start_token_id = tokenizer.bos_token_id
-
     model.to(DEVICE).eval()
     _caption_bundle = {"processor": processor, "tokenizer": tokenizer, "model": model}
     return _caption_bundle
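For reference, the parse_num_sentences helper deleted above, reproduced so it runs standalone; with it gone, the story branch of run() (last hunk below) always falls back to op_story's default of five sentences:

```python
import re

def parse_num_sentences(cmd: str, default: int = 5) -> int:
    # Behavior of the helper removed in this commit.
    m = re.search(r"(\d+)\s*(?:sentences?|sentence)", (cmd or "").lower())
    if m:
        try:
            return max(1, min(int(m.group(1)), 20))  # clamp to 1..20
        except Exception:
            pass
    return default

print(parse_num_sentences("write a 3 sentence story"))  # -> 3
print(parse_num_sentences("write a story"))             # -> 5 (default)
```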
@@ -100,15 +93,8 @@ def get_story_pipe():
     global _story_pipe
     if _story_pipe is not None:
         return _story_pipe
-    # Load a fast tokenizer explicitly to kill “slow” warning
-    story_tok = AutoTokenizer.from_pretrained(STORY_MODEL_ID, use_fast=True, token=HF_TOKEN)
-    _story_pipe = hf_pipeline(
-        "text2text-generation",
-        model=STORY_MODEL_ID,
-        tokenizer=story_tok,
-        device_map="auto",  # lets HF place layers smartly; will still run CPU if no GPU
-        # Do NOT pass torch_dtype here (deprecated in some paths). We'll rely on device_map.
-    )
+    # Flan-T5 works with text2text-generation
+    _story_pipe = hf_pipeline("text2text-generation", model=STORY_MODEL_ID, device_map="auto", model_kwargs={"torch_dtype": DTYPE})
     return _story_pipe

 def get_img2img_pipe():
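A minimal smoke test for a story pipeline configured the same way. The prompt is a hypothetical stand-in, and passing torch_dtype through model_kwargs may emit a deprecation warning on recent transformers releases:

```python
import torch
from transformers import pipeline

# flan-t5-large is an encoder-decoder model, hence the text2text-generation task.
pipe = pipeline(
    "text2text-generation",
    model="google/flan-t5-large",
    device_map="auto",  # needs accelerate; falls back to CPU without a GPU
    model_kwargs={"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32},
)
out = pipe("Write exactly two sentences about a lighthouse.", max_new_tokens=64, do_sample=True)
print(out[0]["generated_text"])
```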
@@ -117,8 +103,8 @@ def get_img2img_pipe():
         return _img2img_pipe
     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
         IMG2IMG_MODEL_ID,
-        dtype=DTYPE,  # <-- modern arg (fixes torch_dtype deprecation)
-        safety_checker=None,  # flip to enable if you want
+        torch_dtype=DTYPE,
+        safety_checker=None,  # flip to enable safety if you prefer
         requires_safety_checker=False,
         use_safetensors=True,
     )
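Most of op_cartoon's body sits outside the shown hunks; the sketch below shows how a pipeline loaded this way is typically invoked. The prompt, file names, and parameter values are hypothetical, not taken from the file:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
    use_safetensors=True,
).to(device)

src = Image.open("photo.jpg").convert("RGB")  # hypothetical input file
result = pipe(
    prompt="cartoon style, bold outlines, flat colors",  # hypothetical prompt
    image=src,
    strength=0.6,             # 0..1; higher drifts further from the source
    guidance_scale=7.5,
    num_inference_steps=30,
)
result.images[0].save("cartoon.png")
```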
@@ -130,12 +116,11 @@ def get_img2img_pipe():
     _img2img_pipe = pipe
     return _img2img_pipe

-# ----------------- Ops -----------------
+# ------------- Ops -------------
 @torch.inference_mode()
 def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4) -> str:
     bundle = get_caption_bundle()
     proc, tok, mdl = bundle["processor"], bundle["tokenizer"], bundle["model"]
-    # Let processor handle size; accepts any input resolution
     pv = proc(image.convert("RGB"), return_tensors="pt").pixel_values.to(DEVICE)
     out = mdl.generate(
         pixel_values=pv,
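The caption path, condensed into a self-contained sketch that mirrors get_caption_bundle plus op_caption; the blank test image is a stand-in for a real upload:

```python
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer

model_id = "nlpconnect/vit-gpt2-image-captioning"
proc = AutoImageProcessor.from_pretrained(model_id)
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
mdl = VisionEncoderDecoderModel.from_pretrained(model_id).eval()
if tok.pad_token is None:  # GPT-2 decoder: reuse EOS as PAD
    tok.pad_token = tok.eos_token
mdl.config.pad_token_id = tok.pad_token_id

img = Image.new("RGB", (384, 384), "white")  # stand-in; use a real photo
pv = proc(img, return_tensors="pt").pixel_values
with torch.inference_mode():
    ids = mdl.generate(pixel_values=pv, max_new_tokens=32, num_beams=4)
print(tok.batch_decode(ids, skip_special_tokens=True)[0].strip())
```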
@@ -149,14 +134,15 @@ def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4) -> str:
 def op_story(
     image: Image.Image,
     num_sentences: int = 5,
-    max_new_tokens: int = 220,  # enough headroom
-    min_new_tokens: int = 80,   # force >= ~80 tokens to discourage 1-line outputs
+    max_new_tokens: int = 220,  # allow enough room
+    min_new_tokens: int = 80,   # force >= ~80 tokens (~5 sentences)
     temperature: float = 0.9,
     top_p: float = 0.92,
     no_repeat_ngram_size: int = 3,
 ) -> str:
-    # Ground with the caption (keeps story on-topic)
+    # Ground the story with a caption of the image
     caption = op_caption(image)
+
     prompt = (
         f"Write exactly {num_sentences} sentences based on this image description. "
         "Use vivid sensory details. No title, no lists, no bullet points, no numbered lines, no dialogue.\n"
@@ -169,20 +155,23 @@ def op_story(
         do_sample=True,
         temperature=temperature,
         top_p=top_p,
-        min_new_tokens=min_new_tokens,  # key to prevent early stop
+        min_new_tokens=min_new_tokens,  # <- prevents early stop
         max_new_tokens=max_new_tokens,
         no_repeat_ngram_size=no_repeat_ngram_size,
         num_return_sequences=1,
     )
     text = out[0]["generated_text"].strip()

-    # Final safety belt: clamp to exactly N sentences
+    # Safety belt: hard-trim to exactly N sentences
+    import re
     sents = re.split(r'(?<=[.!?])\s+', text)
     sents = [s.strip() for s in sents if s.strip()]
     if len(sents) >= num_sentences:
         text = " ".join(sents[:num_sentences])
     return text

+
+
 @torch.inference_mode()
 def op_cartoon(image: Image.Image, steps=DEFAULT_STEPS, guidance=DEFAULT_GUIDANCE, strength=DEFAULT_STRENGTH, seed: Optional[int]=None):
     img = _resize_max(image.convert("RGB"))
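The sentence "safety belt" splits on terminal punctuation and keeps the first num_sentences pieces; a quick standalone check (sample text invented):

```python
import re

text = "The sun rose. Gulls cried! Waves hissed? A dog barked. The tide turned. One too many."
sents = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
print(" ".join(sents[:5]))  # keeps the first five sentences, drops the sixth
```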
@@ -201,15 +190,17 @@ def op_cartoon(image: Image.Image, steps=DEFAULT_STEPS, guidance=DEFAULT_GUIDANCE, strength=DEFAULT_STRENGTH, seed: Optional[int]=None):
     )
     return result.images[0]

-# ----------------- Router -----------------
+# ------------- Router -------------
 def route_command(command: str) -> str:
     c = (command or "").lower()
     if any(k in c for k in ["cartoon", "sketch", "comic", "anime", "illustration"]):
         return "cartoon"
     if any(k in c for k in ["story", "poem", "narrative", "write"]):
         return "story"
-    return "caption"  # default / "describe", "caption", etc.
+    # default / describe / caption / explain
+    return "caption"

+# ------------- Gradio App -------------
 def run(image: Image.Image, command: str, steps: int, guidance: float, strength: float, seed: str):
     if image is None:
         raise gr.Error("Upload an image.")
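Routing is keyword-based, and the cartoon keywords are checked before the story keywords, so mixed commands lean cartoon. Reproduced standalone to show the precedence:

```python
def route_command(command: str) -> str:
    c = (command or "").lower()
    if any(k in c for k in ["cartoon", "sketch", "comic", "anime", "illustration"]):
        return "cartoon"
    if any(k in c for k in ["story", "poem", "narrative", "write"]):
        return "story"
    return "caption"

print(route_command("make it a cartoon"))    # cartoon
print(route_command("write a comic story"))  # cartoon ("comic" matches first)
print(route_command("describe the photo"))   # caption (default)
```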
@@ -218,25 +209,18 @@ def run(image: Image.Image, command: str, steps: int, guidance: float, strength: float, seed: str):
         img = op_cartoon(image, steps=steps, guidance=guidance, strength=strength, seed=int(seed) if seed else None)
         return None, img, f"Mode: cartoon ({steps} steps, guidance {guidance}, strength {strength}, seed {seed or 'None'})"
     elif mode == "story":
-        n = parse_num_sentences(command, default=5)
-        txt = op_story(image, num_sentences=n)
-        return txt, None, f"Mode: story ({n} sentences)"
+        txt = op_story(image)
+        return txt, None, "Mode: story"
     else:
         txt = op_caption(image)
         return txt, None, "Mode: caption"

-# ----------------- Gradio UI -----------------
 with gr.Blocks(css="footer {visibility:hidden}") as demo:
     gr.Markdown("# Image Command Router — describe • cartoonize • write a story")
     with gr.Row():
         with gr.Column():
             inp_img = gr.Image(type="pil", label="Image")
-            inp_cmd = gr.Textbox(
-                label="Command",
-                placeholder='e.g., "describe the photo", "make the photo look like a cartoon", "write a 5 sentence story about the image"',
-                lines=2,
-                value="describe the photo"
-            )
+            inp_cmd = gr.Textbox(label="Command", placeholder='e.g., "describe the photo", "make the photo look like a cartoon", "write a story about the image"', lines=2, value="describe the photo")
             with gr.Accordion("Advanced (cartoon mode)", open=False):
                 steps = gr.Slider(1, 75, value=DEFAULT_STEPS, step=1, label="Steps")
                 guidance = gr.Slider(0.0, 15.0, value=DEFAULT_GUIDANCE, step=0.1, label="Guidance (CFG)")
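The diff view cuts off inside the UI builder, so the rest of app.py (output widgets, button wiring, launch) is not shown here. For orientation only, a self-contained toy app with the same three-output shape; every widget name and the stub run() below are assumptions, not taken from this commit:

```python
import gradio as gr

def run(image, command, steps, guidance, strength, seed):
    # Stand-in for app.py's run(); returns (text, image, status) like the real one.
    return f"echo: {command}", image, "Mode: demo"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp_img = gr.Image(type="pil", label="Image")
            inp_cmd = gr.Textbox(label="Command", value="describe the photo")
            steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
            guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")
            strength = gr.Slider(0.1, 0.95, value=0.6, step=0.05, label="Strength")
            seed = gr.Textbox(label="Seed (optional)")
            btn = gr.Button("Run")
        with gr.Column():
            out_text = gr.Textbox(label="Text output")
            out_img = gr.Image(label="Image output")
            status = gr.Markdown()
    btn.click(run, inputs=[inp_img, inp_cmd, steps, guidance, strength, seed],
              outputs=[out_text, out_img, status])

if __name__ == "__main__":
    demo.launch()
```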
 