achase25 committed
Commit 787d948 · verified · 1 parent: d5419d7

Create app.py

Files changed (1): app.py +253 −0
app.py ADDED
@@ -0,0 +1,253 @@
# app.py — Image+Command router: "describe photo" (caption), "write a story" (text), "make it a cartoon" (img2img)
# Deps:
#   pip install -q gradio transformers diffusers accelerate torch safetensors pillow
import os
import re
from typing import Optional

import torch
import gradio as gr
from PIL import Image

from transformers import (
    VisionEncoderDecoderModel,
    AutoImageProcessor,
    AutoTokenizer,
    pipeline as hf_pipeline,
)

from diffusers import StableDiffusionImg2ImgPipeline

# ----------------- Config -----------------
CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "nlpconnect/vit-gpt2-image-captioning")
# For longer/better stories you can set: google/flan-t5-xl (needs ~10–12GB VRAM) or google/flan-ul2 (heavy)
STORY_MODEL_ID = os.getenv("STORY_MODEL_ID", "google/flan-t5-large")
IMG2IMG_MODEL_ID = os.getenv("IMG2IMG_MODEL_ID", "stabilityai/stable-diffusion-2-1")

MAX_IMG_SIDE = int(os.getenv("MAX_IMG_SIDE", "768"))
DEFAULT_STEPS = int(os.getenv("STEPS", "30"))
DEFAULT_GUIDANCE = float(os.getenv("GUIDANCE", "7.5"))
DEFAULT_STRENGTH = float(os.getenv("STRENGTH", "0.6"))

DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
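
# Example (illustrative, not part of the original file): all of the defaults above
# can be overridden via environment variables before launch, e.g.:
#   STORY_MODEL_ID=google/flan-t5-base MAX_IMG_SIDE=512 python app.py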

# ----------------- Caches -----------------
_caption_bundle = {}
_story_pipe = None
_img2img_pipe = None

# ----------------- Utils -----------------
def _resize_max(img: Image.Image, max_side: int = MAX_IMG_SIDE) -> Image.Image:
    w, h = img.size
    if max(w, h) <= max_side:
        return img
    if w >= h:
        new_w = max_side
        new_h = int(h * (max_side / w))
    else:
        new_h = max_side
        new_w = int(w * (max_side / h))
    # Snap to multiples of 8 for SD pipelines
    return img.resize((new_w // 8 * 8, new_h // 8 * 8), Image.LANCZOS)

def _seeded_generator(seed: Optional[int]):
    if seed is None or str(seed).strip() == "":
        return None
    try:
        seed = int(seed)
    except Exception:
        return None
    dev = "cuda" if DEVICE == "cuda" else "cpu"
    return torch.Generator(device=dev).manual_seed(seed)

def parse_num_sentences(cmd: str, default: int = 5) -> int:
    m = re.search(r"(\d+)\s*sentences?", (cmd or "").lower())
    if m:
        try:
            n = int(m.group(1))
            return max(1, min(n, 20))  # keep sane bounds
        except Exception:
            pass
    return default
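
# Illustrative behavior of parse_num_sentences (not executed, shown for reference):
#   parse_num_sentences("write a 3 sentence story")  -> 3
#   parse_num_sentences("write a story")             -> 5 (default)
#   parse_num_sentences("write a 100 sentence saga") -> 20 (clamped)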

# ----------------- Loaders -----------------
def get_caption_bundle():
    global _caption_bundle
    if _caption_bundle:
        return _caption_bundle
    # use_fast=True avoids "slow processor/tokenizer" warnings
    processor = AutoImageProcessor.from_pretrained(CAPTION_MODEL_ID, use_fast=True, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID, use_fast=True, token=HF_TOKEN)
    model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_ID, token=HF_TOKEN)

    # GPT-2 decoders have no pad token by default -> set pad=eos; set ids so generate() is happy
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    if getattr(model.config, "decoder_start_token_id", None) is None and tokenizer.bos_token_id is not None:
        model.config.decoder_start_token_id = tokenizer.bos_token_id

    model.to(DEVICE).eval()
    _caption_bundle = {"processor": processor, "tokenizer": tokenizer, "model": model}
    return _caption_bundle

def get_story_pipe():
    global _story_pipe
    if _story_pipe is not None:
        return _story_pipe
    # Load a fast tokenizer explicitly to silence the "slow tokenizer" warning
    story_tok = AutoTokenizer.from_pretrained(STORY_MODEL_ID, use_fast=True, token=HF_TOKEN)
    _story_pipe = hf_pipeline(
        "text2text-generation",
        model=STORY_MODEL_ID,
        tokenizer=story_tok,
        device_map="auto",  # lets HF place layers smartly; still runs on CPU if no GPU
        # Do NOT pass torch_dtype here (deprecated in some paths); rely on device_map instead.
    )
    return _story_pipe

def get_img2img_pipe():
    global _img2img_pipe
    if _img2img_pipe is not None:
        return _img2img_pipe
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        IMG2IMG_MODEL_ID,
        dtype=DTYPE,  # modern arg (replaces the deprecated torch_dtype)
        safety_checker=None,  # flip to enable if you want
        requires_safety_checker=False,
        use_safetensors=True,
    )
    pipe = pipe.to(DEVICE)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    _img2img_pipe = pipe
    return _img2img_pipe
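
# Optional (an assumption, not part of the original setup): on low-VRAM GPUs you can
# also trade speed for memory with the standard diffusers call
#   pipe.enable_attention_slicing()
# placed before the pipeline is cached above.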

# ----------------- Ops -----------------
@torch.inference_mode()
def op_caption(image: Image.Image, max_new_tokens: int = 32, num_beams: int = 4) -> str:
    bundle = get_caption_bundle()
    proc, tok, mdl = bundle["processor"], bundle["tokenizer"], bundle["model"]
    # Let the processor handle resizing; it accepts any input resolution
    pv = proc(image.convert("RGB"), return_tensors="pt").pixel_values.to(DEVICE)
    out = mdl.generate(
        pixel_values=pv,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    return tok.decode(out[0], skip_special_tokens=True).strip()

def op_story(
    image: Image.Image,
    num_sentences: int = 5,
    max_new_tokens: int = 220,  # enough headroom
    min_new_tokens: int = 80,   # force >= ~80 tokens to discourage one-line outputs
    temperature: float = 0.9,
    top_p: float = 0.92,
    no_repeat_ngram_size: int = 3,
) -> str:
    # Ground the story in the caption (keeps it on-topic)
    caption = op_caption(image)
    prompt = (
        f"Write exactly {num_sentences} sentences based on this image description. "
        "Use vivid sensory details. No title, no lists, no bullet points, no numbered lines, no dialogue.\n"
        f"Image description: {caption}\n\nStory:"
    )

    story_pipe = get_story_pipe()
    out = story_pipe(
        prompt,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        min_new_tokens=min_new_tokens,  # key to preventing early stops
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=1,
    )
    text = out[0]["generated_text"].strip()

    # Final safety belt: clamp to at most N sentences (shorter outputs pass through unchanged)
    sents = re.split(r'(?<=[.!?])\s+', text)
    sents = [s.strip() for s in sents if s.strip()]
    if len(sents) >= num_sentences:
        text = " ".join(sents[:num_sentences])
    return text

@torch.inference_mode()
def op_cartoon(image: Image.Image, steps=DEFAULT_STEPS, guidance=DEFAULT_GUIDANCE, strength=DEFAULT_STRENGTH, seed: Optional[int] = None):
    img = _resize_max(image.convert("RGB"))
    gen = _seeded_generator(seed)
    pipe = get_img2img_pipe()
    prompt = "cartoon, cel-shaded, flat colors, bold outlines, clean lineart, anime style, comic book"
    negative = "photorealistic, blurry, noisy, artifacts, distorted, watermark"
    result = pipe(
        prompt=prompt,
        negative_prompt=negative,
        image=img,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=gen,
    )
    return result.images[0]

# ----------------- Router -----------------
def route_command(command: str) -> str:
    c = (command or "").lower()
    if any(k in c for k in ["cartoon", "sketch", "comic", "anime", "illustration"]):
        return "cartoon"
    if any(k in c for k in ["story", "poem", "narrative", "write"]):
        return "story"
    return "caption"  # default / "describe", "caption", etc.

def run(image: Image.Image, command: str, steps: int, guidance: float, strength: float, seed: str):
    if image is None:
        raise gr.Error("Upload an image.")
    mode = route_command(command)
    if mode == "cartoon":
        # Guard the seed parse so a non-numeric entry falls back to None instead of crashing
        try:
            seed_val = int(seed) if str(seed).strip() else None
        except (TypeError, ValueError):
            seed_val = None
        img = op_cartoon(image, steps=steps, guidance=guidance, strength=strength, seed=seed_val)
        return None, img, f"Mode: cartoon ({steps} steps, guidance {guidance}, strength {strength}, seed {seed_val or 'None'})"
    elif mode == "story":
        n = parse_num_sentences(command, default=5)
        txt = op_story(image, num_sentences=n)
        return txt, None, f"Mode: story ({n} sentences)"
    else:
        txt = op_caption(image)
        return txt, None, "Mode: caption"

# ----------------- Gradio UI -----------------
with gr.Blocks(css="footer {visibility:hidden}") as demo:
    gr.Markdown("# Image Command Router — describe • cartoonize • write a story")
    with gr.Row():
        with gr.Column():
            inp_img = gr.Image(type="pil", label="Image")
            inp_cmd = gr.Textbox(
                label="Command",
                placeholder='e.g., "describe the photo", "make the photo look like a cartoon", "write a 5 sentence story about the image"',
                lines=2,
                value="describe the photo",
            )
            with gr.Accordion("Advanced (cartoon mode)", open=False):
                steps = gr.Slider(1, 75, value=DEFAULT_STEPS, step=1, label="Steps")
                guidance = gr.Slider(0.0, 15.0, value=DEFAULT_GUIDANCE, step=0.1, label="Guidance (CFG)")
                strength = gr.Slider(0.1, 1.0, value=DEFAULT_STRENGTH, step=0.05, label="Strength (how much to change)")
                seed = gr.Textbox(value="", label="Seed (optional int)")
            go = gr.Button("Run", variant="primary")
        with gr.Column():
            out_text = gr.Textbox(label="Text output", lines=10)
            out_image = gr.Image(label="Image output")
            status = gr.Markdown()
    go.click(run, inputs=[inp_img, inp_cmd, steps, guidance, strength, seed], outputs=[out_text, out_image, status], scroll_to_output=True)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)
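
# Minimal client sketch (an assumption, not part of the app): with Gradio >= 4 the
# click handler is exposed under the auto-generated api_name "/run" (derived from the
# function name), so a remote call could look like this with gradio_client installed:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   text, image, status = client.predict(
#       handle_file("photo.jpg"),   # image
#       "describe the photo",       # command
#       30, 7.5, 0.6, "",           # steps, guidance (CFG), strength, seed
#       api_name="/run",
#   )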