File size: 18,441 Bytes
17008d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
import subprocess
import torch
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import threading
import time
import urllib.parse
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse

# ASGI application object; the route decorators further down register on it.
app = FastAPI(
    description="Auto-captions images from middleware server using Florence-2",
    title="Florence-2 Image Captioning Server",
)
import threading
import time
import urllib.parse

# Attempt to install flash-attn (plus timm/einops) at startup.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is layered on top of the current environment
# instead of replacing it: a one-entry `env` dict would drop PATH, HOME, proxy
# settings and any virtualenv variables from the child process, so `pip` could
# resolve to a different interpreter than the one running this server.
try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation timm einops',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        check=True,
        shell=True,
    )
except subprocess.CalledProcessError as e:
    # Best effort: a failed install is reported and startup continues.
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")

# Determine the device to use: CUDA when torch reports it, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load Florence-2-large model and processor. Any failure is reported and
# leaves both globals set to None, which the callers below check for.
try:
    vision_language_model = (
        AutoModelForCausalLM
        .from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
        .to(device)
        .eval()
    )
    vision_language_processor = AutoProcessor.from_pretrained(
        'microsoft/Florence-2-large',
        trust_remote_code=True,
    )
    print("✓ Florence-2-large model loaded successfully")
except Exception as e:
    print(f"Error loading Florence-2-large model: {e}")
    vision_language_model = vision_language_processor = None

def load_image_from_url(image_url):
    """Fetch ``image_url`` over HTTP and return the decoded picture in RGB mode.

    The request uses a 30 second timeout. Any failure along the way (request
    error, HTTP error status, undecodable payload) is re-raised as
    ``ValueError`` carrying the original message.
    """
    try:
        reply = requests.get(image_url, timeout=30)
        reply.raise_for_status()
        payload = BytesIO(reply.content)
        return Image.open(payload).convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}")

def process_image_description(model, processor, image):
    """Caption ``image`` with the given model/processor pair.

    ``image`` may be a PIL image or anything ``Image.fromarray`` accepts. The
    ``<MORE_DETAILED_CAPTION>`` task prompt is used, generation runs without
    gradients (beam search, 3 beams, up to 1024 new tokens), and the entry of
    the post-processed result stored under that task key is returned.
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"

    # Accept array-like input by wrapping it in a PIL image first.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    batch = processor(text=task_prompt, images=image, return_tensors="pt").to(device)
    generation_kwargs = dict(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    with torch.no_grad():
        token_ids = model.generate(**generation_kwargs)

    # Special tokens are kept when decoding; post_process_generation is handed
    # the raw decoded text together with the task prompt and image size.
    decoded = processor.batch_decode(token_ids, skip_special_tokens=False)
    parsed = processor.post_process_generation(
        decoded[0],
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    return parsed[task_prompt]

def describe_image(uploaded_image, model_choice):
    """Return a caption for ``uploaded_image``, or a human-readable message.

    ``model_choice`` is accepted but not consulted; the module-level
    Florence-2-large model is always used. Missing input, an unloaded model
    and captioning errors are all reported as plain strings rather than raised.
    """
    if uploaded_image is None:
        return "Please upload an image."
    if vision_language_model is None:
        return "Florence-2-large model failed to load."

    captioner, preprocessor = vision_language_model, vision_language_processor
    try:
        caption = process_image_description(captioner, preprocessor, uploaded_image)
    except Exception as e:
        return f"Error generating caption: {str(e)}"
    return caption

def describe_image_from_url(image_url, model_choice=None):
    """Caption the image found at ``image_url`` and return a result dict.

    On success the dict has ``status``, ``model`` (the ``model_choice`` value
    echoed back), ``caption`` and ``image_size``. Every failure, including a
    missing URL or an unloaded model, yields ``{"error": <message>}`` instead
    of raising.
    """
    try:
        if not image_url:
            return {"error": "image_url is required"}
        if vision_language_model is None:
            return {"error": "Florence-2-large model not available"}

        # Fetch, then caption with the module-level large model.
        picture = load_image_from_url(image_url)
        text = process_image_description(
            vision_language_model,
            vision_language_processor,
            picture,
        )

        dimensions = {"width": picture.width, "height": picture.height}
        return {
            "status": "success",
            "model": model_choice,
            "caption": text,
            "image_size": dimensions,
        }
    except Exception as e:
        return {"error": f"Error processing image: {str(e)}"}


# Base URLs of the image server and the data-collection service; both can be
# overridden through the environment.
# NOTE(review): the IMAGE_SERVER_BASE fallback is a single space, so URLs built
# from it are not usable unless the variable is set — confirm this is intended.
IMAGE_SERVER_BASE = os.environ.get("IMAGE_SERVER_BASE", " ")
DATA_COLLECTION_BASE = os.environ.get("DATA_COLLECTION_BASE", "https://fred808-flow.hf.space")
# Sent as `requester_id` on release calls; falls back to a per-process value.
REQUESTER_ID = os.environ.get("FLO_REQUESTER_ID", f"florence-2-{os.getpid()}")
MODEL_CHOICE = "Florence-2-large"  # Always use large model


def sanitize_name(name: str, max_len: int = 200) -> str:
    """Sanitize a filename while preserving its extension where possible.

    Steps: stringify and strip ``name``, collapse each whitespace run to one
    underscore, drop every character outside ``[A-Za-z0-9_.-]``, then shorten
    to ``max_len`` characters by trimming the stem and keeping the extension.

    Args:
        name: Value to sanitize; converted with ``str()`` first.
        max_len: Maximum length of the returned name.

    Returns:
        The sanitized name, or ``"file"`` when nothing usable remains.
    """
    import re
    name = str(name).strip()
    # replace spaces with underscores
    name = re.sub(r"\s+", "_", name)
    # remove any characters not alphanumeric, dot, dash, or underscore
    name = re.sub(r"[^A-Za-z0-9_.-]", "", name)
    if len(name) > max_len:
        base, ext = os.path.splitext(name)
        # Clamp at zero: an extension longer than max_len would otherwise give
        # a negative slice bound and a result longer than max_len.
        keep = max(max_len - len(ext), 0)
        name = (base[:keep] + ext)[:max_len]
    return name or "file"

def _build_download_url(course: str, video: str, frame: str) -> "tuple[str, str]":
    """Build the middleware ``/download`` URL for one frame.

    The ``course`` query parameter is the course folder with a ``_frames``
    suffix (added when missing); the ``file`` parameter is
    ``"<video>/<frame>"``. All three segments are passed through
    ``sanitize_name`` and then percent-encoded.

    Returns:
        ``(url, safe_frame)``: the download URL and the sanitized frame name.
    """
    # Frames live under a "{course}_frames" folder; accept either spelling.
    course_dir = course if course.endswith("_frames") else f"{course}_frames"

    # Sanitize and encode path segments
    safe_course = sanitize_name(course_dir)
    safe_video = sanitize_name(video)
    safe_frame = sanitize_name(frame)

    file_param = f"{safe_video}/{safe_frame}"
    url = f"{IMAGE_SERVER_BASE.rstrip('/')}/download?course={urllib.parse.quote(safe_course, safe='')}&file={urllib.parse.quote(file_param, safe='')}"
    print(f"[BACKGROUND] Built URL: {url}")
    return url, safe_frame


def _download_bytes(url: str, timeout: int = 30, chunk_size=32768):
    """Stream ``url`` into memory, printing progress as chunks arrive.

    Args:
        url: Address to fetch.
        timeout: Timeout in seconds handed to ``requests.get``.
        chunk_size: Size of each chunk requested from ``iter_content``.

    Returns:
        ``(body_bytes, content_type)`` on success, ``(None, None)`` on any
        failure; the failure is printed, not raised.
    """
    try:
        print(f"[BACKGROUND] Starting download: {url}")
        # The context manager closes the streamed response even when an error
        # interrupts the transfer part-way, so the connection is not leaked.
        with requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            content = BytesIO()
            total_size = int(response.headers.get('content-length', 0))
            print(f"[BACKGROUND] Total size: {total_size} bytes")

            bytes_read = 0
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    content.write(chunk)
                    bytes_read += len(chunk)
                    # Percentage is only shown when Content-Length was sent.
                    if total_size:
                        print(f"\rDownloading: {bytes_read}/{total_size} bytes ({(bytes_read/total_size)*100:.1f}%)", end="", flush=True)
            print()  # New line after progress
            print(f"[BACKGROUND] Download complete: {bytes_read} bytes")
            return content.getvalue(), response.headers.get('content-type')
    except Exception as e:
        print(f"[BACKGROUND] download failed {url}: {e}")
        return None, None


def _post_submit(caption: str, image_name: str, course: str, image_url: str, image_bytes: bytes):
    """POST one caption and its image to ``{DATA_COLLECTION_BASE}/submit``.

    The image is sent as the multipart file field ``image``; caption, image
    name, course and source URL are sent as form fields.

    Returns:
        ``(status_code, body)`` where ``body`` is the parsed JSON response, or
        the raw response text when it is not JSON. ``(None, None)`` when the
        request itself fails.
    """
    submit_url = f"{DATA_COLLECTION_BASE.rstrip('/')}/submit"
    files = {'image': (image_name, image_bytes, 'application/octet-stream')}
    data = {'caption': caption, 'image_name': image_name, 'course': course, 'image_url': image_url}

    print(f"[BACKGROUND] Submitting to {submit_url}")
    print(f"[BACKGROUND] Image name: {image_name}")
    print(f"[BACKGROUND] Course: {course}")
    print(f"[BACKGROUND] Caption length: {len(caption)} chars")

    try:
        r = requests.post(submit_url, data=data, files=files, timeout=30)
        # `status_code` is the attribute requests exposes; the former
        # `r.status` raised AttributeError, which the handler below turned
        # into a reported failure even after a successful POST.
        print(f"[BACKGROUND] Submit response status: {r.status_code}")
        try:
            resp = r.json()
        except ValueError:
            # Body is not JSON; hand back the raw text instead.
            print(f"[BACKGROUND] Submit response text: {r.text}")
            return r.status_code, r.text
        print(f"[BACKGROUND] Submit response JSON: {resp}")
        return r.status_code, resp
    except Exception as e:
        print(f"[BACKGROUND] Submit POST failed: {e}")
        return None, None


def _release_frame(course: str, video: str, frame: str):
    """POST a frame-release request to the image server.

    The three identifiers are percent-encoded into the URL path and
    ``REQUESTER_ID`` is sent as the ``requester_id`` query parameter. The
    response is not inspected; exceptions are printed and swallowed.
    """
    try:
        base = IMAGE_SERVER_BASE.rstrip('/')
        segments = [urllib.parse.quote(part, safe='') for part in (course, video, frame)]
        release_url = f"{base}/middleware/release/frame/" + "/".join(segments)
        requests.post(release_url, params={"requester_id": REQUESTER_ID}, timeout=10)
    except Exception as e:
        print(f"[BACKGROUND] release frame failed: {e}")


def _release_course(course: str):
    """POST a course-release request to the image server.

    ``course`` is percent-encoded into the URL path and ``REQUESTER_ID`` is
    sent as the ``requester_id`` query parameter. The response is not
    inspected; exceptions are printed and swallowed.
    """
    try:
        base = IMAGE_SERVER_BASE.rstrip('/')
        encoded_course = urllib.parse.quote(course, safe='')
        release_url = f"{base}/middleware/release/course/{encoded_course}"
        requests.post(release_url, params={"requester_id": REQUESTER_ID}, timeout=10)
    except Exception as e:
        print(f"[BACKGROUND] release course failed: {e}")


# Background worker implementation
def _wait_for_model(timeout_s: int = 120) -> bool:
    """Poll once per second until the model global is set.

    Returns True as soon as ``vision_language_model`` is not None, or False
    after ``timeout_s`` unsuccessful checks.
    """
    for _ in range(timeout_s):
        if vision_language_model is not None:
            return True
        time.sleep(1)
    return False


def _caption_and_submit(course: str, filename: str, download_url: str) -> None:
    """Download one image, caption it and post the result to the collector."""
    img_bytes, _ = _download_bytes(download_url)
    if not img_bytes:
        print(f"[BACKGROUND] Failed to download {filename}")
        return

    # Process with Florence. The image and its bytes are locals of this helper,
    # so they are dropped when it returns.
    try:
        pil_img = Image.open(BytesIO(img_bytes)).convert('RGB')

        print(f"[BACKGROUND] Generating caption for {filename}")
        caption = process_image_description(
            vision_language_model, vision_language_processor, pil_img
        )
        print(f"[BACKGROUND] Generated caption for {filename}:")
        print("-" * 40)
        print(caption)
        print("-" * 40)

        # Submit result
        print(f"[BACKGROUND] Submitting caption to {DATA_COLLECTION_BASE}/submit")
        status, resp = _post_submit(caption, filename, course, download_url, img_bytes)
        if status and status < 400:
            print(f"[BACKGROUND] Successfully submitted {filename} (status={status})")
            if resp:
                print(f"[BACKGROUND] Response: {resp}")
        else:
            print(f"[BACKGROUND] Failed to submit {filename}: status={status}, response={resp}")
    except Exception as e:
        print(f"[BACKGROUND] Error processing {filename}: {e}")
        return

    time.sleep(0.5)  # Small delay between images


def _process_first_course(base_url: str) -> None:
    """Fetch the course list and caption every image of the first course.

    NOTE(review): only the first listed course is handled and nothing here
    marks it as finished, so repeated calls re-process it for as long as the
    server keeps listing it first — confirm the server drops finished courses.
    """
    courses_url = f"{base_url}/courses"
    print(f"[BACKGROUND] Fetching courses from {courses_url}")

    try:
        r = requests.get(courses_url, timeout=15)
        r.raise_for_status()
        courses = r.json().get('courses')

        if not courses:
            print("[BACKGROUND] No courses found, waiting...")
            time.sleep(3)
            return

        # Entries may be dicts carrying 'course_folder' or bare names.
        course_entry = courses[0]
        if isinstance(course_entry, dict):
            course = course_entry.get('course_folder')
        else:
            course = str(course_entry)

        if not course:
            print("[BACKGROUND] Invalid course entry")
            time.sleep(2)
            return

        print(f"[BACKGROUND] Processing course: {course}")
        quoted_course = urllib.parse.quote(course, safe='')

        # Get images list; either {'images': [...]} or a bare list is accepted.
        r = requests.get(f"{base_url}/images/{quoted_course}", timeout=15)
        r.raise_for_status()
        images_data = r.json()
        if isinstance(images_data, dict):
            image_list = images_data.get('images', [])
        else:
            image_list = images_data

        if not image_list:
            print(f"[BACKGROUND] No images found for course {course}")
            time.sleep(2)
            return

        print(f"[BACKGROUND] Found {len(image_list)} images")

        for img_entry in image_list:
            # One bad entry must not abort the rest of the course.
            try:
                if isinstance(img_entry, dict):
                    filename = img_entry.get('filename')
                    if not filename:
                        continue
                else:
                    filename = str(img_entry)

                download_url = f"{base_url}/images/{quoted_course}/{urllib.parse.quote(filename, safe='')}"
                print(f"[BACKGROUND] Downloading {download_url}")
                _caption_and_submit(course, filename, download_url)
            except Exception as e:
                print(f"[BACKGROUND] Error in image loop: {e}")

        print(f"[BACKGROUND] Completed course {course}")
        time.sleep(1)

    except Exception as e:
        print(f"[BACKGROUND] Error in course loop: {e}")
        time.sleep(5)


def background_worker():
    """Background worker that processes images from the middleware server.

    Waits up to 120 seconds for the model global to be set and returns if it
    never is. Otherwise loops forever: fetch the course list, caption every
    image of the first course and submit each caption. Errors are printed and
    followed by a short pause; they never end the loop.
    """
    print("[BACKGROUND] Starting worker, waiting for model...")

    if not _wait_for_model():
        print("[BACKGROUND] Model not available after timeout")
        return

    print(f"[BACKGROUND] Model {MODEL_CHOICE} ready, starting processing loop")

    # Trailing slashes are stripped, as the other URL builders in this module
    # do, so a base ending in "/" does not produce "//" in request paths.
    base_url = IMAGE_SERVER_BASE.rstrip('/')

    while True:
        try:
            _process_first_course(base_url)
        except Exception as e:
            print(f"[BACKGROUND] Main loop error: {e}")
            time.sleep(5)


def _start_worker_thread():
    """Launch ``background_worker`` on a daemon thread and return the thread."""
    worker = threading.Thread(target=background_worker)
    # Daemon flag is set before start() so the thread cannot keep the process alive.
    worker.daemon = True
    worker.start()
    return worker


# FastAPI endpoints for status/health
@app.get("/")
async def root():
    # Service banner: fixed identity fields plus the current model/device state.
    model_ready = vision_language_model is not None
    return dict(
        name="Florence-2 Image Captioning Server",
        status="running",
        model="Florence-2-large",
        model_loaded=model_ready,
        device=device,
    )

@app.get("/health")
async def health():
    # Always answers "healthy"; callers must read model_loaded to learn whether
    # captioning is actually available.
    model_ready = vision_language_model is not None
    return dict(
        status="healthy",
        model="Florence-2-large",
        model_loaded=model_ready,
        device=device,
        model_choice=MODEL_CHOICE,
    )



@app.get("/analyze")
async def analyze_get(image_url: str = None, model_choice: str = None):
    """Analyze an image by URL. Usage: /analyze?image_url=https://...&model_choice=Florence-2-base"""
    # A missing URL is a client error and goes through FastAPI's own handler.
    if not image_url:
        raise HTTPException(status_code=400, detail="image_url query parameter is required")

    try:
        # model_choice is only forwarded to describe_image_from_url; the module
        # default stands in when the query parameter is absent.
        outcome = describe_image_from_url(image_url, model_choice or MODEL_CHOICE)
        succeeded = isinstance(outcome, dict) and outcome.get("status") == "success"
        if not succeeded:
            # The helper's whole result is passed back under "error".
            return JSONResponse(status_code=400, content={"success": False, "error": outcome})

        body = {
            "success": True,
            "caption": outcome.get("caption"),
            "image_size": outcome.get("image_size"),
        }
        return JSONResponse(content=body)
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "error": str(e)})


@app.post("/analyze")
async def analyze_post(file: UploadFile = File(None), model_choice: str = Form(None)):
    """Analyze an uploaded image (multipart/form-data). Returns caption JSON."""
    # model_choice is accepted as a form field but not consulted below.
    if file is None:
        raise HTTPException(status_code=400, detail="file is required")

    try:
        raw = await file.read()

        # An undecodable upload is the client's fault: 400, not 500.
        try:
            picture = Image.open(BytesIO(raw)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}")

        if vision_language_model is None:
            raise HTTPException(status_code=503, detail="Florence-2-large model not loaded")

        caption = process_image_description(
            vision_language_model,
            vision_language_processor,
            picture,
        )
        return JSONResponse(content={"success": True, "caption": caption})

    except HTTPException:
        # Let FastAPI render the status/detail chosen above.
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "error": str(e)})

# Port to listen on: the PORT environment variable when present (as on
# Hugging Face Spaces), otherwise 7860.
port = int(os.getenv("PORT", 7860))

# Serve the app with uvicorn on all interfaces when executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=port, host="0.0.0.0")