rapsar committed on
Commit
fd24341
·
verified ·
1 Parent(s): 1cb811a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -0
app.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import time
4
+ import os
5
+ from PIL import Image, ImageOps, ImageDraw
6
+ import numpy as np
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ import torch
9
+
10
# Default side length (in pixels) of the square drawing canvas.
DEFAULT_CANVAS = 64
# Default brush size handed to gr.Brush for the image editor.
DEFAULT_BRUSH = 2
12
+
13
def make_blank_canvas(w: int, h: int) -> Image.Image:
    """Return a new all-black 8-bit grayscale (mode "L") image of size ``w`` x ``h``.

    The ImageEditor converts the image to its own ``image_mode`` when the
    result is used as the editor value.
    """
    size = (w, h)
    black = 0
    return Image.new("L", size, black)
16
+
17
def pil_to_rowstring(img: Image.Image) -> str:
    """Serialize an image as CSV text of 8-bit grayscale values.

    The image is converted to mode "L". Each pixel row becomes one line of
    comma-separated integers terminated by ';'; lines are joined with
    newlines, with no trailing newline.
    """
    pixels = np.array(img.convert("L"), dtype=np.uint8)
    row_texts = []
    for pixel_row in pixels:
        row_texts.append(",".join(str(value) for value in pixel_row.tolist()) + ";")
    return "\n".join(row_texts)
21
+
22
def pil_to_binstring(img: Image.Image, thresh: int = 128) -> str:
    """Serialize an image as CSV text of 0/1 values using a threshold.

    The image is converted to mode "L"; a pixel becomes 1 when its value is
    greater than or equal to ``int(thresh)`` and 0 otherwise. Each row is
    written as comma-separated digits terminated by ';', rows joined by
    newlines with no trailing newline.
    """
    gray = np.array(img.convert("L"), dtype=np.uint8)
    cutoff = int(thresh)
    bits = (gray >= cutoff).astype(np.uint8)
    return "\n".join(
        ",".join(str(bit) for bit in bit_row.tolist()) + ";"
        for bit_row in bits
    )
27
+
28
# --- LLM helpers (lazy load per model) ---
# Maps model_id -> (tokenizer, model) so a checkpoint is loaded at most once
# per process.
_LLM_CACHE = {}


def load_llm(model_id: str):
    """Return the ``(tokenizer, model)`` pair for ``model_id``, loading on first use.

    With CUDA available the model is loaded in float16 with
    ``device_map="auto"``; otherwise it is loaded in float32 and moved to the
    CPU. If the tokenizer has no pad token, its EOS token is used as padding.
    The pair is stored in ``_LLM_CACHE`` and reused on later calls.
    """
    cached = _LLM_CACHE.get(model_id)
    if cached is not None:
        return cached

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        dtype, device_map = torch.float16, "auto"
    else:
        dtype, device_map = torch.float32, None

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): trust_remote_code=True lets the model repo run its own
    # Python on load; confirm it is actually required for the offered models.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
    )
    if not has_cuda:
        model = model.to("cpu")

    entry = (tokenizer, model)
    _LLM_CACHE[model_id] = entry
    return entry
57
+
58
@spaces.GPU
def run_llm(prompt: str, max_new_tokens: int = 64, temperature: float = 0.0, model_id: str = "meta-llama/Llama-3.2-1B") -> str:
    """Continue ``prompt`` with a causal LM and return only the newly generated text.

    Args:
        prompt: Text the model should continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        model_id: Hub id of the model, loaded (and cached) via ``load_llm``.

    Returns:
        The decoded continuation with surrounding whitespace stripped, or a
        string of the form ``"[LLM error: ...]"`` if anything fails. Callers
        detect failure by that prefix, so errors are reported, not raised.
    """
    max_prompt_tokens = 2048
    try:
        tok, mdl = load_llm(model_id)
        device = next(mdl.parameters()).device

        # Tokenize without truncation and keep the LAST max_prompt_tokens
        # tokens. Tokenizer truncation cuts on the right by default, which
        # would drop the end of the prompt, i.e. the very rows the model is
        # supposed to continue from.
        enc = tok(prompt, return_tensors="pt")
        input_ids = enc["input_ids"][:, -max_prompt_tokens:].to(device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask[:, -max_prompt_tokens:].to(device)

        sampling = temperature > 0
        with torch.inference_mode():
            outputs = mdl.generate(
                input_ids,
                # Pass the mask explicitly: pad may be aliased to EOS, so it
                # cannot be inferred reliably from the ids.
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                do_sample=sampling,
                temperature=temperature if sampling else None,
                top_p=None,
                pad_token_id=tok.eos_token_id,
                eos_token_id=tok.eos_token_id,
                use_cache=True,
            )

        # Decode only the tokens produced after the (possibly shortened) prompt.
        new_tokens = outputs[0][input_ids.shape[1]:]
        text = tok.decode(new_tokens, skip_special_tokens=True)
        return text.strip()

    except Exception as e:
        return f"[LLM error: {e}]"
87
+
88
def csv_single_line(csv_multiline: str) -> str:
    """Collapse multi-line CSV text onto one line.

    Newline characters are removed; the ';' row terminators are kept so rows
    can still be told apart. A falsy input (``None`` or ``""``) yields ``""``.
    """
    if not csv_multiline:
        return ""
    return "".join(csv_multiline.split("\n"))
91
+
92
def parse_csv_image(s: str, width: int):
    """Parse ';'-separated rows of ','-separated integers into a mode "L" image.

    Only the digit characters of each field are kept; fields without digits
    are skipped and values are clamped to 0..255. Rows that yield no values
    are dropped; the remaining rows are zero-padded or truncated to ``width``.

    Returns the image, or ``None`` when no row could be parsed or any
    exception occurs along the way.
    """
    try:
        pixel_rows = []
        for segment in s.strip().split(";"):
            if segment == "":
                continue
            values = []
            for field in segment.split(","):
                digits = "".join(filter(str.isdigit, field))
                if digits:
                    values.append(min(255, max(0, int(digits))))
            if not values:
                continue
            # pad/truncate to the canvas width
            missing = width - len(values)
            if missing > 0:
                values = values + [0] * missing
            else:
                values = values[:width]
            pixel_rows.append(values)

        if not pixel_rows:
            return None
        pixels = np.array(pixel_rows, dtype=np.uint8)
        return Image.fromarray(pixels, mode="L")
    except Exception:
        return None
118
+
119
def apply_settings(canvas_px):
    """Build a fresh ImageEditor for a square canvas of ``canvas_px`` pixels.

    The editor is recreated with the same configuration each time and given a
    new blank (black) canvas as its value so that the requested size is
    enforced.
    """
    side = int(canvas_px)
    palette = ["black", "#404040", "#808080", "#C0C0C0", "white"]
    brush = gr.Brush(
        default_size=DEFAULT_BRUSH,
        colors=palette,
        default_color="white",  # white stands out on the new black canvas
        color_mode="fixed",
    )
    return gr.ImageEditor(
        canvas_size=(side, side),
        value=make_blank_canvas(side, side),
        image_mode="RGBA",
        brush=brush,
        eraser=gr.Eraser(default_size=1),
        transforms=("crop", "resize"),
        height=500,
    )
137
+
138
+ # Process uploaded image: resize to canvas width, grayscale, update editor + preview
139
def process_upload(im, canvas_px, scale, invert, binarize, bin_thresh):
    """Fit an uploaded image onto the canvas and derive its preview and CSV text.

    The editor's "background" layer is converted to grayscale, resized to the
    canvas width (keeping its aspect ratio) and pasted at the top-left of a
    black square canvas. The preview and CSV are produced from that canvas,
    optionally inverted; the CSV is binarized with ``bin_thresh`` when
    ``binarize`` is truthy, and the preview is upscaled by ``scale``.

    Returns ``(editor_value, preview, text)``, or ``(None, None, None)`` when
    there is no background image.
    """
    background = im.get("background") if im else None
    if background is None:
        return None, None, None

    gray = Image.fromarray(background).convert("L")
    src_w, src_h = gray.size

    # Target width is the canvas size; fall back to the source width if unusable.
    side = src_w if canvas_px is None else int(canvas_px)
    if side <= 0:
        side = src_w
    fitted_h = max(1, round(src_h * side / max(1, src_w)))
    fitted = gray.resize((side, fitted_h), Image.LANCZOS)

    # Square, canvas-sized grayscale image with the upload anchored at (0, 0).
    canvas_gray = Image.new("L", (side, side), 0)
    canvas_gray.paste(fitted, (0, 0))

    # Inversion affects the preview and the CSV, not the editor value.
    rendered = ImageOps.invert(canvas_gray) if invert else canvas_gray
    if bool(binarize):
        text = pil_to_binstring(rendered, bin_thresh)
    else:
        text = pil_to_rowstring(rendered)

    factor = max(1, 8 if scale is None else int(scale))
    preview_size = (rendered.width * factor, rendered.height * factor)
    preview = rendered.resize(preview_size, Image.NEAREST)
    return canvas_gray, preview, text
175
+
176
def make_preview(im, scale, invert, binarize, bin_thresh):
    """Render the editor's composite as an upscaled preview plus CSV text.

    The "composite" layer is converted to grayscale and optionally inverted.
    The CSV is produced at canvas resolution (binarized with ``bin_thresh``
    when ``binarize`` is truthy); the preview is the same image enlarged by
    ``scale`` with nearest-neighbour resampling.

    Returns ``(preview, text)``, or ``(None, "")`` when there is no composite.
    """
    composite = None if im is None else im.get("composite")
    if composite is None:
        return None, ""

    gray = Image.fromarray(composite).convert("L")  # canvas-sized grayscale
    rendered = ImageOps.invert(gray) if invert else gray

    if bool(binarize):
        text = pil_to_binstring(rendered, bin_thresh)
    else:
        text = pil_to_rowstring(rendered)

    factor = max(1, 8 if scale is None else int(scale))
    preview_size = (rendered.width * factor, rendered.height * factor)
    preview = rendered.resize(preview_size, Image.NEAREST)
    return preview, text
192
+
193
def extrapolate_with_llm(csv_text, canvas_px, out_rows, model_id):
    """Ask the LLM to continue the CSV image and render input + output together.

    The CSV is flattened to one line and used directly as the prompt, with a
    token budget of ``out_rows * width * 2``. The prompt and the generated
    text are then parsed together (';' ends a row), padded to the widest row,
    rendered as a grayscale image 512 pixels wide, and marked with a red line
    where the generated rows begin.

    Returns ``(text, image)``: the generated text with a newline after each
    ';' and the rendered image. On an LLM error the error string is returned
    with ``None``; if nothing can be parsed the raw generation is returned
    with ``None``.
    """
    one_line = csv_single_line(csv_text)
    # Number of non-empty ';'-terminated rows that come from the input.
    input_rows_count = sum(1 for seg in (one_line or "").split(";") if seg.strip())

    try:
        width = int(canvas_px)
    except Exception:
        width = DEFAULT_CANVAS
    max_tokens = int(out_rows) * width * 2

    # The single-line CSV is fed to the model as-is.
    gen = run_llm(one_line, int(max_tokens), model_id=model_id)
    if gen.startswith("[LLM error:"):
        return gen, None

    # Parse INPUT + OUTPUT together; ';' marks end-of-row.
    combined = (one_line or "") + (gen or "")
    parsed = []
    for segment in combined.split(";"):
        if not segment.strip():
            continue
        vals = []
        for field in segment.split(","):
            field = field.strip()
            if not field:
                continue
            try:
                number = int(float(field))
            except Exception:
                continue
            # clamp to 0-255 grayscale
            vals.append(min(255, max(0, number)))
        if vals:
            parsed.append(vals)

    if not parsed:
        return gen, None

    # Zero-pad every row to the widest one so the array is rectangular.
    max_w = max(len(vals) for vals in parsed)
    arr_rows = [vals + [0] * (max_w - len(vals)) for vals in parsed]

    arr = np.array(arr_rows, dtype=np.uint8)
    # If the array is binary (only 0 and 1), rescale to 0-255.
    if set(np.unique(arr).tolist()) <= {0, 1}:
        arr = arr * 255
    img = Image.fromarray(arr, mode="L")

    # Resize to width=512, preserving the aspect ratio.
    target_w = 512
    orig_w, orig_h = img.size
    target_h = max(1, round(orig_h * target_w / max(1, orig_w)))
    img = img.resize((target_w, target_h), Image.NEAREST)

    # Thin red separator at the boundary between input and output rows,
    # with the input row count mapped from original to resized height.
    if input_rows_count > 0 and orig_h > 0:
        y = round(input_rows_count * target_h / orig_h)
        y = max(0, min(target_h - 1, y))
        img_rgb = img.convert("RGB")
        ImageDraw.Draw(img_rgb).line(
            [(0, y), (img_rgb.width - 1, y)], fill=(255, 0, 0), width=1
        )
        img = img_rgb

    return (gen or "").replace(";", ";\n"), img
270
+
271
# Custom theme: hub-hosted base theme with a custom block background fill.
theme = gr.Theme.from_hub('gstaff/xkcd')
theme.set(block_background_fill="#7ffacd8e")

# NOTE(review): the layout nesting below was reconstructed from a listing that
# had lost its indentation; confirm it against the intended UI.
with gr.Blocks(theme=theme, title="Image Extrapolation with LLMs") as demo:
    gr.Markdown("### Extrapolate images with LLMs")
    gr.Markdown("Draw or upload an image, and let an LLM continue the pattern!")

    with gr.Row():
        # Left column: canvas, preview and generation settings.
        with gr.Column(scale=1, min_width=220):
            canvas_px = gr.Slider(32, 128, value=DEFAULT_CANVAS, step=1, label="Canvas size (px)")
            preview_scale = gr.Slider(1, 16, value=8, step=1, label="Preview scale (×)")
            invert_preview = gr.Checkbox(value=False, label="Invert preview")

            with gr.Accordion("Binarize", open=False):
                binarize_csv = gr.Checkbox(value=False, label="Turn 0-255 into 0/1")
                bin_thresh = gr.Slider(0, 255, value=128, step=1, label="Threshold")

            out_rows_default_value = 3
            out_rows = gr.Slider(1, 16, value=out_rows_default_value, step=1, label="Number of output rows")
            llm_choice = gr.Dropdown(
                label="LLM model",
                choices=[
                    "meta-llama/Llama-3.2-1B",
                    "meta-llama/Llama-3.2-3B",
                    "HuggingFaceTB/SmolLM2-1.7B",
                    # NOTE(review): the original entry "HuggingFaceTB/SmolLM2-7B"
                    # does not appear to be a published checkpoint, so picking
                    # it could only produce a load error; replaced with the
                    # 360M model. Confirm the intended choice.
                    "HuggingFaceTB/SmolLM2-360M",
                ],
                value="meta-llama/Llama-3.2-1B",
            )
            out_tokens_info = gr.Markdown(f"**Output tokens:** {DEFAULT_CANVAS * out_rows_default_value * 2}")

        # Right column: drawing surface and its scaled preview.
        with gr.Column(scale=4):
            im = gr.ImageEditor(
                type="numpy",
                canvas_size=(DEFAULT_CANVAS, DEFAULT_CANVAS),
                image_mode="RGBA",
                brush=gr.Brush(
                    default_size=DEFAULT_BRUSH,
                    colors=["black", "#404040", "#808080", "#C0C0C0", "white"],
                    default_color="black",
                    color_mode="fixed",
                ),
                eraser=gr.Eraser(default_size=1),
                transforms=("crop", "resize"),
                height=500,
            )
            im_preview = gr.Image(height=512, label="Preview (scaled)")

    preview_text = gr.Textbox(
        label="Preview as CSV (rows end with ';')",
        lines=12,
        interactive=False,
        show_copy_button=True,
        # Was 5, which contradicted lines=12 (the maximum was below the minimum).
        max_lines=12,
    )

    # Helper to update button label
    def update_button_label(model_id):
        return f"Extrapolate with LLM ({model_id.split('/')[-1]})"

    extrap_btn = gr.Button(
        value="Extrapolate with LLM (Llama-3.2-1B)",
        variant="primary",
    )

    llm_text = gr.Textbox(
        label="LLM output (single-line CSV)",
        lines=6,
        interactive=False,
        show_copy_button=True,
    )
    llm_image = gr.Image(label="LLM parsed image", height=512)

    # Event handlers
    preview_inputs = [im, preview_scale, invert_preview, binarize_csv, bin_thresh]
    preview_outputs = [im_preview, preview_text]

    # Refresh the preview only after the editor has been reset; as two
    # independent listeners the preview could be built from the old canvas.
    canvas_px.change(apply_settings, inputs=[canvas_px], outputs=im).then(
        make_preview, inputs=preview_inputs, outputs=preview_outputs
    )

    im.upload(
        process_upload,
        inputs=[im, canvas_px, preview_scale, invert_preview, binarize_csv, bin_thresh],
        outputs=[im, im_preview, preview_text],
    )
    im.change(make_preview, inputs=preview_inputs, outputs=preview_outputs, show_progress="hidden")
    for control in (preview_scale, invert_preview, binarize_csv, bin_thresh):
        control.change(make_preview, inputs=preview_inputs, outputs=preview_outputs)

    extrap_btn.click(
        extrapolate_with_llm,
        inputs=[preview_text, canvas_px, out_rows, llm_choice],
        outputs=[llm_text, llm_image],
    )

    # Update button label dynamically when LLM model changes
    llm_choice.change(update_button_label, inputs=[llm_choice], outputs=[extrap_btn])

    def update_tokens(out_rows, canvas_px):
        """Return the Markdown line showing the token budget (rows * width * 2)."""
        try:
            width = int(canvas_px)
        except Exception:
            width = DEFAULT_CANVAS
        tokens = int(out_rows) * width * 2
        return f"**Output tokens:** {tokens}"

    # Keep the token counter in sync with both sliders and on first page load.
    out_rows.change(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)
    canvas_px.change(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)

    demo.load(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)
373
+
374
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()