Fred808 committed on
Commit
71669df
·
verified ·
1 Parent(s): 9835a3a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +498 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Florence-2 image-captioning server: module imports and FastAPI app setup.

Fix: threading, time, and urllib.parse were imported twice (once before and
once after the FastAPI app was created); the duplicates are removed.
"""
import subprocess
import torch
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import threading
import time
import urllib.parse
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse

# Single ASGI application instance; endpoints are registered on it below.
app = FastAPI(
    title="Florence-2 Image Captioning Server",
    description="Auto-captions images from middleware server using Florence-2"
)
22
# Best-effort install of optional acceleration deps (flash-attn, timm, einops).
# BUG FIX: the original passed env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
# which REPLACES the whole environment — pip then runs without PATH/HOME and
# fails. Merge the flag into a copy of the current environment instead.
# FLASH_ATTENTION_SKIP_CUDA_BUILD avoids a long CUDA compile on hosts without
# a full toolkit.
try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation timm einops',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        check=True,
        shell=True,
    )
except subprocess.CalledProcessError as e:
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")
28
+
29
# Pick the compute device once at import time; downstream .to(device) calls
# and the status endpoints all reuse this plain string ("cuda" or "cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"
31
+
32
def _load_florence(checkpoint, label):
    """Load one Florence-2 checkpoint onto `device`.

    Returns (model, processor), or (None, None) if loading fails so the
    server can still start with the other variant.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
        print(f"✓ {label} model loaded successfully")
        return model, processor
    except Exception as e:
        print(f"Error loading {label.lower()} model: {e}")
        return None, None


# Load both model variants at startup; each may independently be unavailable.
vision_language_model_base, vision_language_processor_base = _load_florence(
    'microsoft/Florence-2-base', 'Base')
vision_language_model_large, vision_language_processor_large = _load_florence(
    'microsoft/Florence-2-large', 'Large')
51
+
52
def load_image_from_url(image_url):
    """Fetch `image_url` over HTTP and return it as an RGB PIL image.

    Raises:
        ValueError: if the request fails or the payload is not a decodable image.
    """
    try:
        resp = requests.get(image_url, timeout=30)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}")
61
+
62
def process_image_description(model, processor, image):
    """Run the Florence-2 <MORE_DETAILED_CAPTION> task on `image`.

    Accepts a PIL image or a raw array; returns the caption string.
    """
    # Normalize array inputs to PIL before tokenization.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    task_prompt = "<MORE_DETAILED_CAPTION>"
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device)

    # Deterministic beam search; no gradients needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )

    # Keep special tokens: post_process_generation needs them to parse the task output.
    raw_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        raw_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return parsed[task_prompt]
85
+
86
def describe_image(uploaded_image, model_choice):
    """Caption an uploaded PIL image with the chosen Florence-2 variant.

    Returns the caption string on success, or a human-readable error string
    (missing image, unavailable model, invalid choice, or generation failure).
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Resolve the (model, processor) pair; either variant may have failed to
    # load at startup, in which case its global is None.
    if model_choice == "Florence-2-base":
        if vision_language_model_base is None:
            return "Base model failed to load."
        model, processor = vision_language_model_base, vision_language_processor_base
    elif model_choice == "Florence-2-large":
        if vision_language_model_large is None:
            return "Large model failed to load."
        model, processor = vision_language_model_large, vision_language_processor_large
    else:
        return "Invalid model choice."

    try:
        return process_image_description(model, processor, uploaded_image)
    except Exception as e:
        return f"Error generating caption: {str(e)}"
108
+
109
def describe_image_from_url(image_url, model_choice):
    """Fetch an image by URL, caption it, and return a JSON-serializable dict.

    Successful results carry status/model/caption/image_size keys; every
    failure mode is reported as {"error": <message>} rather than raising.
    """
    try:
        # Validate inputs before any network activity.
        if not image_url:
            return {"error": "image_url is required"}
        if model_choice not in ("Florence-2-base", "Florence-2-large"):
            return {"error": "Invalid model choice. Use 'Florence-2-base' or 'Florence-2-large'"}

        image = load_image_from_url(image_url)

        # Pick whichever variant was requested; it may not have loaded.
        if model_choice == "Florence-2-base":
            if vision_language_model_base is None:
                return {"error": "Base model not available"}
            model, processor = vision_language_model_base, vision_language_processor_base
        else:
            if vision_language_model_large is None:
                return {"error": "Large model not available"}
            model, processor = vision_language_model_large, vision_language_processor_large

        caption = process_image_description(model, processor, image)

        return {
            "status": "success",
            "model": model_choice,
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
        }

    except Exception as e:
        return {"error": f"Error processing image: {str(e)}"}
145
+
146
+
147
# Middleware / data-collection configuration, all overridable via environment.
# NOTE(review): the IMAGE_SERVER_BASE default is a single space — clearly a
# placeholder; the background worker cannot reach a server until it is set.
# Confirm the intended default URL.
IMAGE_SERVER_BASE = os.getenv("IMAGE_SERVER_BASE", " ")
DATA_COLLECTION_BASE = os.getenv("DATA_COLLECTION_BASE", "https://fred808-flow.hf.space")
# Id this worker presents when claiming/releasing frames on the middleware;
# defaults to a per-process value so concurrent workers don't collide.
REQUESTER_ID = os.getenv("FLO_REQUESTER_ID", f"florence-2-{os.getpid()}")
# Which loaded Florence-2 variant the background worker uses.
MODEL_CHOICE = os.getenv("FLO_MODEL_CHOICE", "Florence-2-base")
151
+
152
+
153
def sanitize_name(name: str, max_len: int = 200) -> str:
    """Sanitize a filename while preserving its extension.

    Whitespace runs become underscores, every character outside
    [A-Za-z0-9_.-] is dropped, and over-long names are truncated from the
    stem so the extension survives. An empty result falls back to "file".
    """
    import re

    cleaned = str(name).strip()
    cleaned = re.sub(r"\s+", "_", cleaned)              # whitespace -> underscores
    cleaned = re.sub(r"[^A-Za-z0-9_.-]", "", cleaned)   # drop unsafe characters

    if len(cleaned) > max_len:
        stem, ext = os.path.splitext(cleaned)
        cleaned = stem[: max_len - len(ext)] + ext

    return cleaned or "file"
165
+
166
def _build_download_url(course: str, video: str, frame: str) -> tuple:
    """Build the middleware /download URL for one frame.

    Returns:
        (url, safe_frame): the fully-encoded download URL and the sanitized
        frame filename actually used inside it.

    Fixes: the original annotated the return as ``-> str`` although it
    returns a 2-tuple, and kept a dead ``base_course`` reassignment.
    """
    # The middleware /download endpoint expects the 'file' parameter to be a
    # path relative to the course folder (e.g. "video_name/frame.jpg").
    # Frames live under a "{base_course}_frames" folder.
    if course.endswith("_frames"):
        course_dir = course
    else:
        course_dir = f"{course}_frames"

    # Sanitize, then URL-encode every path segment independently so slashes
    # inside names cannot escape the intended directory.
    safe_course = sanitize_name(course_dir)
    safe_video = sanitize_name(video)
    safe_frame = sanitize_name(frame)

    file_param = f"{safe_video}/{safe_frame}"
    url = (
        f"{IMAGE_SERVER_BASE.rstrip('/')}/download"
        f"?course={urllib.parse.quote(safe_course, safe='')}"
        f"&file={urllib.parse.quote(file_param, safe='')}"
    )
    print(f"[BACKGROUND] Built URL: {url}")
    return url, safe_frame
187
+
188
+
189
def _download_bytes(url: str, timeout: int = 30, chunk_size=32768):
    """Stream `url` into memory with progress logging.

    Returns (payload_bytes, content_type) on success, (None, None) on any
    failure (logged, never raised).
    """
    try:
        print(f"[BACKGROUND] Starting download: {url}")
        resp = requests.get(url, timeout=timeout, stream=True)
        resp.raise_for_status()

        expected = int(resp.headers.get('content-length', 0))
        print(f"[BACKGROUND] Total size: {expected} bytes")

        buffer = BytesIO()
        received = 0
        for chunk in resp.iter_content(chunk_size=chunk_size):
            if not chunk:
                continue
            buffer.write(chunk)
            received += len(chunk)
            # Progress is only meaningful when the server sent content-length.
            if expected:
                print(f"\rDownloading: {received}/{expected} bytes ({(received/expected)*100:.1f}%)", end="", flush=True)
        print()  # terminate the in-place progress line
        print(f"[BACKGROUND] Download complete: {received} bytes")
        return buffer.getvalue(), resp.headers.get('content-type')
    except Exception as e:
        print(f"[BACKGROUND] download failed {url}: {e}")
        return None, None
211
+
212
+
213
def _post_submit(caption: str, image_name: str, course: str, image_url: str, image_bytes: bytes):
    """POST one captioned image to the data-collection /submit endpoint.

    Returns:
        (status_code, parsed_json_or_text) on a completed request, or
        (None, None) if the POST itself failed.
    """
    submit_url = f"{DATA_COLLECTION_BASE.rstrip('/')}/submit"
    files = {'image': (image_name, image_bytes, 'application/octet-stream')}
    data = {'caption': caption, 'image_name': image_name, 'course': course, 'image_url': image_url}

    print(f"[BACKGROUND] Submitting to {submit_url}")
    print(f"[BACKGROUND] Image name: {image_name}")
    print(f"[BACKGROUND] Course: {course}")
    print(f"[BACKGROUND] Caption length: {len(caption)} chars")

    try:
        r = requests.post(submit_url, data=data, files=files, timeout=30)
        # BUG FIX: requests.Response has no `.status` attribute; the original
        # `r.status` raised AttributeError here, so the outer except reported
        # "Submit POST failed" and returned (None, None) for every request
        # that actually completed.
        print(f"[BACKGROUND] Submit response status: {r.status_code}")
        try:
            resp = r.json()
            print(f"[BACKGROUND] Submit response JSON: {resp}")
            return r.status_code, resp
        except Exception:
            # Non-JSON body (e.g. an HTML error page) — return the raw text.
            print(f"[BACKGROUND] Submit response text: {r.text}")
            return r.status_code, r.text
    except Exception as e:
        print(f"[BACKGROUND] Submit POST failed: {e}")
        return None, None
236
+
237
+
238
def _release_frame(course: str, video: str, frame: str):
    """Best-effort release of a claimed frame back to the middleware; never raises."""
    try:
        quote = urllib.parse.quote
        release_url = (
            f"{IMAGE_SERVER_BASE.rstrip('/')}/middleware/release/frame/"
            f"{quote(course, safe='')}/{quote(video, safe='')}/{quote(frame, safe='')}"
        )
        requests.post(release_url, params={"requester_id": REQUESTER_ID}, timeout=10)
    except Exception as e:
        print(f"[BACKGROUND] release frame failed: {e}")
244
+
245
+
246
def _release_course(course: str):
    """Best-effort release of a claimed course back to the middleware; never raises."""
    try:
        quote = urllib.parse.quote
        release_url = (
            f"{IMAGE_SERVER_BASE.rstrip('/')}/middleware/release/course/"
            f"{quote(course, safe='')}"
        )
        requests.post(release_url, params={"requester_id": REQUESTER_ID}, timeout=10)
    except Exception as e:
        print(f"[BACKGROUND] release course failed: {e}")
252
+
253
+
254
# Background worker implementation
def background_worker():
    """Pull frames from the middleware, caption them, submit the results.

    Runs forever; intended to be started once in a daemon thread. Waits up
    to ~120s for the configured model (MODEL_CHOICE) to finish loading, then
    loops: list courses -> list images for the first course -> for each image
    download, caption, and POST to the data-collection service.

    Fix: the progress log messages previously printed the literal placeholder
    "(unknown)" instead of the filename being processed.
    """
    print("[BACKGROUND] Starting worker, waiting for model...")

    # Wait for the selected model to be ready (models load at import time,
    # possibly slowly on first download).
    waited = 0
    while waited < 120:
        if MODEL_CHOICE == "Florence-2-base" and vision_language_model_base:
            break
        if MODEL_CHOICE == "Florence-2-large" and vision_language_model_large:
            break
        time.sleep(1)
        waited += 1

    if waited >= 120:
        print("[BACKGROUND] Model not available after timeout")
        return

    print(f"[BACKGROUND] Model {MODEL_CHOICE} ready, starting processing loop")

    while True:
        try:
            courses_url = f"{IMAGE_SERVER_BASE}/courses"
            print(f"[BACKGROUND] Fetching courses from {courses_url}")

            try:
                r = requests.get(courses_url, timeout=15)
                r.raise_for_status()
                courses_data = r.json()

                if not courses_data.get('courses'):
                    print("[BACKGROUND] No courses found, waiting...")
                    time.sleep(3)
                    continue

                # Always take the first course; the middleware is assumed to
                # rotate/consume entries as they are processed — TODO confirm.
                course_entry = courses_data['courses'][0]
                if isinstance(course_entry, dict):
                    course = course_entry.get('course_folder')
                else:
                    course = str(course_entry)

                if not course:
                    print("[BACKGROUND] Invalid course entry")
                    time.sleep(2)
                    continue

                print(f"[BACKGROUND] Processing course: {course}")

                # List the images available for this course.
                images_url = f"{IMAGE_SERVER_BASE}/images/{urllib.parse.quote(course, safe='')}"
                r = requests.get(images_url, timeout=15)
                r.raise_for_status()
                images_data = r.json()

                # Response may be either {"images": [...]} or a bare list.
                if isinstance(images_data, dict):
                    image_list = images_data.get('images', [])
                else:
                    image_list = images_data

                if not image_list:
                    print(f"[BACKGROUND] No images found for course {course}")
                    time.sleep(2)
                    continue

                print(f"[BACKGROUND] Found {len(image_list)} images")

                for img_entry in image_list:
                    try:
                        # Entries may be bare filenames or dicts with metadata.
                        if isinstance(img_entry, dict):
                            filename = img_entry.get('filename')
                            if not filename:
                                continue
                        else:
                            filename = str(img_entry)

                        download_url = (
                            f"{IMAGE_SERVER_BASE}/images/"
                            f"{urllib.parse.quote(course, safe='')}/"
                            f"{urllib.parse.quote(filename, safe='')}"
                        )
                        print(f"[BACKGROUND] Downloading {download_url}")

                        img_bytes, _ = _download_bytes(download_url)
                        if not img_bytes:
                            print(f"[BACKGROUND] Failed to download {filename}")
                            continue

                        try:
                            pil_img = Image.open(BytesIO(img_bytes)).convert('RGB')

                            # MODEL_CHOICE was validated by the readiness wait
                            # above; anything other than base falls through to
                            # large here (original behavior preserved).
                            if MODEL_CHOICE == "Florence-2-base":
                                model = vision_language_model_base
                                processor = vision_language_processor_base
                            else:
                                model = vision_language_model_large
                                processor = vision_language_processor_large

                            print(f"[BACKGROUND] Generating caption for {filename}")
                            caption = process_image_description(model, processor, pil_img)
                            print(f"[BACKGROUND] Generated caption for {filename}:")
                            print("-" * 40)
                            print(caption)
                            print("-" * 40)

                            print(f"[BACKGROUND] Submitting caption to {DATA_COLLECTION_BASE}/submit")
                            status, resp = _post_submit(caption, filename, course, download_url, img_bytes)
                            if status and status < 400:
                                print(f"[BACKGROUND] Successfully submitted {filename} (status={status})")
                                if resp:
                                    print(f"[BACKGROUND] Response: {resp}")
                            else:
                                print(f"[BACKGROUND] Failed to submit {filename}: status={status}, response={resp}")

                        except Exception as e:
                            print(f"[BACKGROUND] Error processing {filename}: {e}")
                            continue
                        finally:
                            # Drop the decoded image and raw bytes promptly to
                            # keep memory flat across a long-running loop.
                            if 'pil_img' in locals():
                                del pil_img
                            if 'img_bytes' in locals():
                                del img_bytes

                        time.sleep(0.5)  # small delay between images

                    except Exception as e:
                        print(f"[BACKGROUND] Error in image loop: {e}")
                        continue

                print(f"[BACKGROUND] Completed course {course}")
                time.sleep(1)

            except Exception as e:
                print(f"[BACKGROUND] Error in course loop: {e}")
                time.sleep(5)
                continue

        except Exception as e:
            print(f"[BACKGROUND] Main loop error: {e}")
            time.sleep(5)
398
+
399
+
400
def _start_worker_thread():
    """Spawn the background worker in a daemon thread and return the thread.

    NOTE(review): this function is redefined later in the file (that later
    definition wins at runtime); the two should be consolidated.
    """
    worker = threading.Thread(target=background_worker, daemon=True)
    worker.start()
    return worker
405
+
406
+
407
+ # FastAPI endpoints for status/health
408
+ @app.get("/")
409
+ async def root():
410
+ return {
411
+ "name": "Florence-2 Image Captioning Server",
412
+ "status": "running",
413
+ "model_base": vision_language_model_base is not None,
414
+ "model_large": vision_language_model_large is not None,
415
+ "device": device
416
+ }
417
+
418
+ @app.get("/health")
419
+ async def health():
420
+ return {
421
+ "status": "healthy",
422
+ "model_base": vision_language_model_base is not None,
423
+ "model_large": vision_language_model_large is not None,
424
+ "device": device,
425
+ "model_choice": MODEL_CHOICE
426
+ }
427
+
428
# Start background worker thread (daemon) so it doesn't block shutdown.
# NOTE(review): this redefines _start_worker_thread from earlier in the file
# and is the definition in effect; the duplicates should be consolidated.
# Restored the `return t` the earlier definition had so callers can keep a
# handle on the thread (backward compatible: the startup hook ignores it).
def _start_worker_thread():
    t = threading.Thread(target=background_worker, daemon=True)
    t.start()
    return t
432
+
433
# Launch the background worker once the ASGI app starts serving.
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers, but still functional — confirm the FastAPI version
# before migrating.
@app.on_event("startup")
async def startup_event():
    _start_worker_thread()
437
+
438
+
439
+ @app.get("/analyze")
440
+ async def analyze_get(image_url: str = None, model_choice: str = None):
441
+ """Analyze an image by URL. Usage: /analyze?image_url=https://...&model_choice=Florence-2-base"""
442
+ try:
443
+ mc = model_choice or MODEL_CHOICE
444
+ if image_url:
445
+ result = describe_image_from_url(image_url, mc)
446
+ if isinstance(result, dict) and result.get("status") == "success":
447
+ return JSONResponse(content={"success": True, "caption": result.get("caption"), "image_size": result.get("image_size")})
448
+ else:
449
+ return JSONResponse(status_code=400, content={"success": False, "error": result})
450
+ else:
451
+ raise HTTPException(status_code=400, detail="image_url query parameter is required")
452
+ except HTTPException:
453
+ raise
454
+ except Exception as e:
455
+ return JSONResponse(status_code=500, content={"success": False, "error": str(e)})
456
+
457
+
458
+ @app.post("/analyze")
459
+ async def analyze_post(file: UploadFile = File(None), model_choice: str = Form(None)):
460
+ """Analyze an uploaded image (multipart/form-data). Returns caption JSON."""
461
+ try:
462
+ mc = model_choice or MODEL_CHOICE
463
+ if file is None:
464
+ raise HTTPException(status_code=400, detail="file is required")
465
+
466
+ content = await file.read()
467
+ try:
468
+ pil_img = Image.open(BytesIO(content)).convert('RGB')
469
+ except Exception as e:
470
+ raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}")
471
+
472
+ # Choose model
473
+ if mc == "Florence-2-large":
474
+ if vision_language_model_large is None:
475
+ raise HTTPException(status_code=503, detail="Base model not loaded")
476
+ model = vision_language_model_large
477
+ processor = vision_language_processor_large
478
+ else:
479
+ if vision_language_model_large is None:
480
+ raise HTTPException(status_code=503, detail="Large model not loaded")
481
+ model = vision_language_model_large
482
+ processor = vision_language_processor_large
483
+
484
+ caption = process_image_description(model, processor, pil_img)
485
+ return JSONResponse(content={"success": True, "caption": caption})
486
+
487
+ except HTTPException:
488
+ raise
489
+ except Exception as e:
490
+ return JSONResponse(status_code=500, content={"success": False, "error": str(e)})
491
+
492
# Port for Hugging Face Spaces (PORT env var), falling back to 7860 locally.
port = int(os.getenv("PORT", 7860))

# Run the ASGI app directly when executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=port)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers==4.48.0
2
+ timm
3
+ einops
4
+ pillow
5
+ hf_transfer
+ torch
+ fastapi
+ uvicorn
+ requests