# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import asyncio
import os
import subprocess
import threading
import time
import urllib.parse
from io import BytesIO

import requests
import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForCausalLM
# FastAPI application object, served by uvicorn from the __main__ guard at the
# bottom of the file. (threading, time and urllib.parse are already imported at
# the top of the module, so they are not re-imported here.)
app = FastAPI(
    title="Florence-2 Image Captioning Server",
    description="Auto-captions images from middleware server using Florence-2",
)
# Best-effort, import-time install of flash-attn together with timm and einops.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is presumably read by flash-attn's setup to skip
# compiling CUDA kernels. It is layered on top of the current environment:
# subprocess.run(env=...) REPLACES the inherited environment, so passing only this
# one key would hide PATH, HOME, virtualenv settings, etc. from the pip subprocess.
# Note: this is one pip command, so if it fails, timm and einops are not installed
# by it either.
try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation timm einops',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        check=True,
        shell=True,
    )
except subprocess.CalledProcessError as e:
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load Florence-2-large and its processor once, at import time. On any failure,
# both globals end up as None; the handlers and the worker test for that before use.
try:
    vision_language_model = (
        AutoModelForCausalLM
        .from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
        .to(device)
        .eval()
    )
    vision_language_processor = AutoProcessor.from_pretrained(
        'microsoft/Florence-2-large',
        trust_remote_code=True,
    )
    print("✓ Florence-2-large model loaded successfully")
except Exception as e:
    print(f"Error loading Florence-2-large model: {e}")
    vision_language_model = None
    vision_language_processor = None
def load_image_from_url(image_url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        image_url: URL fetched with ``requests.get``.
        timeout: Seconds passed to ``requests.get`` (default 30).

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        ValueError: if the request, the HTTP status check or the image decode
            fails. The underlying exception is chained as ``__cause__``.
    """
    # NOTE(review): the URL can arrive straight from an HTTP query parameter
    # (/analyze), so this can be made to fetch arbitrary hosts (SSRF).
    # Consider restricting allowed schemes/hosts.
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}") from e
def process_image_description(
    model,
    processor,
    image,
    task="<MORE_DETAILED_CAPTION>",
    max_new_tokens=1024,
    num_beams=3,
):
    """Run one Florence-2 task prompt on an image and return the parsed result.

    Args:
        model: Florence-2 generation model (already on ``device``).
        processor: Matching Florence-2 processor.
        image: A PIL image, or anything ``Image.fromarray`` accepts.
        task: Florence-2 task prompt token. The default, "<MORE_DETAILED_CAPTION>",
            produces a long caption string.
        max_new_tokens: Generation length cap (default 1024).
        num_beams: Beam-search width (default 3; sampling is disabled).

    Returns:
        The entry for ``task`` in the processor's post-processed output
        (a string for the captioning tasks).
    """
    # Non-PIL inputs (e.g. arrays) are converted so .width/.height exist below.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    inputs = processor(text=task, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
            early_stopping=False,
            do_sample=False,
            num_beams=num_beams,
        )
    # Special tokens are kept, presumably because post_process_generation parses
    # the task-specific markup out of the raw decoded text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    return parsed[task]
def describe_image(uploaded_image, model_choice):
    """Caption an uploaded image with the Florence-2-large model.

    ``model_choice`` is kept for interface compatibility and is not used.
    Returns the caption text, or a human-readable message string when there is
    no image, the model is missing, or captioning raises.
    """
    if uploaded_image is None:
        return "Please upload an image."
    if vision_language_model is None:
        return "Florence-2-large model failed to load."
    try:
        caption = process_image_description(
            vision_language_model,
            vision_language_processor,
            uploaded_image,
        )
    except Exception as exc:
        return f"Error generating caption: {str(exc)}"
    return caption
def describe_image_from_url(image_url, model_choice=None):
    """Fetch an image by URL and caption it with Florence-2-large.

    Args:
        image_url: URL of the image. An empty or None value is rejected.
        model_choice: Accepted for backward compatibility only. Only
            Florence-2-large is loaded, so this value does not select a model.

    Returns:
        On success, a dict with "status" ("success"), "model" (the model that
        actually produced the caption), "caption" and "image_size"
        ({"width": ..., "height": ...}). On any failure, a dict with a single
        "error" key.
    """
    try:
        if not image_url:
            return {"error": "image_url is required"}
        if vision_language_model is None:
            return {"error": "Florence-2-large model not available"}
        image = load_image_from_url(image_url)
        caption = process_image_description(
            vision_language_model, vision_language_processor, image
        )
        return {
            "status": "success",
            # Report the model that really ran, not the (ignored) requested choice.
            "model": "Florence-2-large",
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
        }
    except Exception as e:
        return {"error": f"Error processing image: {str(e)}"}
# Base URL of the middleware/image server. It has no usable default, so set it in
# the environment. Surrounding whitespace and a trailing slash are stripped so that
# f"{IMAGE_SERVER_BASE}/courses"-style joins produce exactly one "/".
IMAGE_SERVER_BASE = os.getenv("IMAGE_SERVER_BASE", "").strip().rstrip("/")
# Base URL of the data-collection service that receives captions on /submit.
DATA_COLLECTION_BASE = (
    os.getenv("DATA_COLLECTION_BASE", "https://fred808-flow.hf.space").strip().rstrip("/")
)
# requester_id sent to the middleware release endpoints; defaults to a per-process id.
REQUESTER_ID = os.getenv("FLO_REQUESTER_ID", f"florence-2-{os.getpid()}")
MODEL_CHOICE = "Florence-2-large"  # Always use large model
def sanitize_name(name: str, max_len: int = 200) -> str:
    """Sanitize a filename, preserving its extension where possible.

    Whitespace runs become "_", and every character other than ASCII letters,
    digits, "_", "." and "-" is removed. Names longer than ``max_len`` are
    truncated: the base is shortened and the extension is kept, unless the
    extension alone would not fit, in which case the whole name is cut to
    ``max_len``.

    Returns:
        The sanitized name, or "file" if nothing is left.
    """
    import re

    name = str(name).strip()
    name = re.sub(r"\s+", "_", name)
    name = re.sub(r"[^A-Za-z0-9_.-]", "", name)
    if len(name) > max_len:
        base, ext = os.path.splitext(name)
        if len(ext) < max_len:
            name = base[: max_len - len(ext)] + ext
        else:
            # The extension alone is too long: base[:negative] would leave a result
            # longer than max_len, so fall back to plain truncation.
            name = name[:max_len]
    return name or "file"
def _build_download_url(course: str, video: str, frame: str) -> tuple[str, str]:
    """Build the middleware ``/download`` URL for one extracted frame.

    The middleware /download endpoint expects ``course`` to name the frames
    folder ("<course>_frames") and ``file`` to be a path relative to it
    ("<video>/<frame>"). Each segment is passed through ``sanitize_name``, and
    each query value is then percent-encoded as a whole.

    Args:
        course: Course name, with or without the "_frames" suffix.
        video: Video folder name inside the frames folder.
        frame: Frame file name.

    Returns:
        ``(url, safe_frame)``: the full download URL and the sanitized frame name.
    """
    # Frames live under a "{course}_frames" folder; add the suffix only if missing.
    course_dir = course if course.endswith("_frames") else f"{course}_frames"
    safe_course = sanitize_name(course_dir)
    safe_video = sanitize_name(video)
    safe_frame = sanitize_name(frame)
    file_param = f"{safe_video}/{safe_frame}"
    url = (
        f"{IMAGE_SERVER_BASE.rstrip('/')}/download"
        f"?course={urllib.parse.quote(safe_course, safe='')}"
        f"&file={urllib.parse.quote(file_param, safe='')}"
    )
    print(f"[BACKGROUND] Built URL: {url}")
    return url, safe_frame
def _download_bytes(url: str, timeout: int = 30, chunk_size=32768):
    """Stream ``url`` into memory, printing progress when Content-Length is known.

    Args:
        url: URL to GET.
        timeout: Seconds passed to ``requests.get``.
        chunk_size: Bytes per ``iter_content`` chunk.

    Returns:
        ``(body_bytes, content_type)`` on success, or ``(None, None)`` if the
        request, status check or streaming fails (the error is printed).
    """
    try:
        print(f"[BACKGROUND] Starting download: {url}")
        # With stream=True the connection stays open until the body is consumed or
        # the response is closed. The context manager closes it on every path,
        # including mid-stream errors, so connections are not leaked.
        with requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            content = BytesIO()
            total_size = int(response.headers.get('content-length', 0))
            print(f"[BACKGROUND] Total size: {total_size} bytes")
            bytes_read = 0
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    content.write(chunk)
                    bytes_read += len(chunk)
                    if total_size:
                        print(f"\rDownloading: {bytes_read}/{total_size} bytes ({(bytes_read/total_size)*100:.1f}%)", end="", flush=True)
            print()  # New line after progress
            print(f"[BACKGROUND] Download complete: {bytes_read} bytes")
            return content.getvalue(), response.headers.get('content-type')
    except Exception as e:
        print(f"[BACKGROUND] download failed {url}: {e}")
        return None, None
def _post_submit(caption: str, image_name: str, course: str, image_url: str, image_bytes: bytes):
    """POST a caption and the raw image bytes to the data-collection ``/submit`` endpoint.

    The image is sent as multipart file field "image". caption, image_name,
    course and image_url are sent as form fields.

    Returns:
        ``(status_code, body)``, where body is the decoded JSON if the response
        parses as JSON and the raw text otherwise. Returns ``(None, None)`` if
        the request itself fails (the error is printed).
    """
    submit_url = f"{DATA_COLLECTION_BASE.rstrip('/')}/submit"
    files = {'image': (image_name, image_bytes, 'application/octet-stream')}
    data = {'caption': caption, 'image_name': image_name, 'course': course, 'image_url': image_url}
    print(f"[BACKGROUND] Submitting to {submit_url}")
    print(f"[BACKGROUND] Image name: {image_name}")
    print(f"[BACKGROUND] Course: {course}")
    print(f"[BACKGROUND] Caption length: {len(caption)} chars")
    try:
        r = requests.post(submit_url, data=data, files=files, timeout=30)
        # requests.Response exposes the HTTP status as `status_code`. There is no
        # `.status` attribute, and reading it would raise AttributeError and make
        # every submit look failed.
        print(f"[BACKGROUND] Submit response status: {r.status_code}")
        try:
            resp = r.json()
            print(f"[BACKGROUND] Submit response JSON: {resp}")
            return r.status_code, resp
        except Exception:
            print(f"[BACKGROUND] Submit response text: {r.text}")
            return r.status_code, r.text
    except Exception as e:
        print(f"[BACKGROUND] Submit POST failed: {e}")
        return None, None
def _release_frame(course: str, video: str, frame: str):
    """Best-effort POST to the middleware's frame-release endpoint.

    Each path segment is percent-encoded on its own. The HTTP status of the
    response is not checked, and any exception is printed rather than raised.
    """
    try:
        base = IMAGE_SERVER_BASE.rstrip('/')
        encoded = "/".join(
            urllib.parse.quote(part, safe='') for part in (course, video, frame)
        )
        requests.post(
            f"{base}/middleware/release/frame/{encoded}",
            params={"requester_id": REQUESTER_ID},
            timeout=10,
        )
    except Exception as err:
        print(f"[BACKGROUND] release frame failed: {err}")
def _release_course(course: str):
    """Best-effort POST to the middleware's course-release endpoint.

    The course name is percent-encoded as a single path segment. The HTTP
    status is not checked, and any exception is printed rather than raised.
    """
    try:
        base = IMAGE_SERVER_BASE.rstrip('/')
        encoded_course = urllib.parse.quote(course, safe='')
        requests.post(
            f"{base}/middleware/release/course/{encoded_course}",
            params={"requester_id": REQUESTER_ID},
            timeout=10,
        )
    except Exception as err:
        print(f"[BACKGROUND] release course failed: {err}")
# Background worker implementation
def _wait_for_model(timeout_s: int = 120) -> bool:
    """Poll once per second until the global Florence-2 model is loaded.

    Returns True as soon as ``vision_language_model`` is set. Otherwise returns
    False after ``timeout_s`` polls. The state is checked once more after the
    last sleep, so a model that appears during the final second is not missed.
    """
    for _ in range(timeout_s):
        if vision_language_model is not None:
            return True
        time.sleep(1)
    return vision_language_model is not None


def _process_course_image(base: str, course: str, img_entry) -> bool:
    """Download, caption and submit one image of ``course``.

    ``img_entry`` is either a dict with a "filename" key or a value whose
    ``str()`` is the file name.

    Returns:
        True when the image reached the submit call (whatever its HTTP
        outcome). False when it was skipped (no filename, failed download) or
        when decoding, captioning or submitting raised.

    Errors outside that stage (e.g. a non-string filename rejected by ``quote``)
    propagate to the caller. Large buffers (img_bytes, pil_img) are released
    when this function returns.
    """
    if isinstance(img_entry, dict):
        filename = img_entry.get('filename')
        if not filename:
            return False
    else:
        filename = str(img_entry)
    download_url = f"{base}/images/{urllib.parse.quote(course, safe='')}/{urllib.parse.quote(filename, safe='')}"
    print(f"[BACKGROUND] Downloading {download_url}")
    img_bytes, _ = _download_bytes(download_url)
    if not img_bytes:
        print(f"[BACKGROUND] Failed to download {filename}")
        return False
    try:
        pil_img = Image.open(BytesIO(img_bytes)).convert('RGB')
        print(f"[BACKGROUND] Generating caption for {filename}")
        caption = process_image_description(vision_language_model, vision_language_processor, pil_img)
        print(f"[BACKGROUND] Generated caption for {filename}:")
        print("-" * 40)
        print(caption)
        print("-" * 40)
        print(f"[BACKGROUND] Submitting caption to {DATA_COLLECTION_BASE}/submit")
        status, resp = _post_submit(caption, filename, course, download_url, img_bytes)
        if status and status < 400:
            print(f"[BACKGROUND] Successfully submitted {filename} (status={status})")
            if resp:
                print(f"[BACKGROUND] Response: {resp}")
        else:
            print(f"[BACKGROUND] Failed to submit {filename}: status={status}, response={resp}")
    except Exception as e:
        print(f"[BACKGROUND] Error processing {filename}: {e}")
        return False
    return True


def background_worker():
    """Poll the middleware image server forever, captioning and submitting images.

    Waits up to 120 s for the model. Then, on each pass:
      1. Fetch /courses and take the first course.
      2. List /images/<course>.
      3. For each image, download it, caption it and POST it to the
         data-collection /submit endpoint.

    Each error is printed and followed by a short sleep; the loop never returns
    on its own. Returns early only if the model never becomes available.
    """
    print("[BACKGROUND] Starting worker, waiting for model...")
    if not _wait_for_model(120):
        print("[BACKGROUND] Model not available after timeout")
        return
    print(f"[BACKGROUND] Model {MODEL_CHOICE} ready, starting processing loop")
    while True:
        try:
            # Strip a trailing slash, as the other middleware helpers do, so that
            # the joins below never produce "//".
            base = IMAGE_SERVER_BASE.rstrip('/')
            courses_url = f"{base}/courses"
            print(f"[BACKGROUND] Fetching courses from {courses_url}")
            try:
                r = requests.get(courses_url, timeout=15)
                r.raise_for_status()
                courses_data = r.json()
                if not courses_data.get('courses'):
                    print("[BACKGROUND] No courses found, waiting...")
                    time.sleep(3)
                    continue
                # NOTE(review): this always takes the first listed course. Unless the
                # server drops finished courses from /courses, the same course is
                # re-processed on every pass. _release_course/_release_frame are not
                # called here; confirm the intended middleware protocol.
                course_entry = courses_data['courses'][0]
                if isinstance(course_entry, dict):
                    course = course_entry.get('course_folder')
                else:
                    course = str(course_entry)
                if not course:
                    print("[BACKGROUND] Invalid course entry")
                    time.sleep(2)
                    continue
                print(f"[BACKGROUND] Processing course: {course}")
                images_url = f"{base}/images/{urllib.parse.quote(course, safe='')}"
                r = requests.get(images_url, timeout=15)
                r.raise_for_status()
                images_data = r.json()
                # The listing is either {"images": [...]} or a bare list.
                if isinstance(images_data, dict):
                    image_list = images_data.get('images', [])
                else:
                    image_list = images_data
                if not image_list:
                    print(f"[BACKGROUND] No images found for course {course}")
                    time.sleep(2)
                    continue
                print(f"[BACKGROUND] Found {len(image_list)} images")
                for img_entry in image_list:
                    try:
                        if _process_course_image(base, course, img_entry):
                            time.sleep(0.5)  # Small delay between images
                    except Exception as e:
                        print(f"[BACKGROUND] Error in image loop: {e}")
                print(f"[BACKGROUND] Completed course {course}")
                time.sleep(1)
            except Exception as e:
                print(f"[BACKGROUND] Error in course loop: {e}")
                time.sleep(5)
        except Exception as e:
            print(f"[BACKGROUND] Main loop error: {e}")
            time.sleep(5)
def _start_worker_thread():
    """Launch ``background_worker`` on a daemon thread and return the thread.

    Daemon mode means the thread does not keep the process alive at exit.
    NOTE(review): nothing in this file is seen calling this function (e.g. from
    a startup hook); confirm that the worker is actually started somewhere.
    """
    worker = threading.Thread(daemon=True, target=background_worker)
    worker.start()
    return worker
# FastAPI endpoints for status/health
# NOTE(review): no route decorator (e.g. @app.get(...)) is visible on these
# handlers, so as written they are never registered on `app`. Confirm that
# registration exists.
async def root():
    """Describe the service, its model, whether the model loaded, and the device."""
    info = {
        "name": "Florence-2 Image Captioning Server",
        "status": "running",
        "model": "Florence-2-large",
    }
    info["model_loaded"] = vision_language_model is not None
    info["device"] = device
    return info
async def health():
    """Health payload: fixed "healthy" status plus model-load state and device."""
    loaded = vision_language_model is not None
    return dict(
        status="healthy",
        model="Florence-2-large",
        model_loaded=loaded,
        device=device,
        model_choice=MODEL_CHOICE,
    )
async def analyze_get(image_url: str = None, model_choice: str = None):
    """Caption an image fetched from a URL.

    Usage: /analyze?image_url=https://...
    ``model_choice`` is accepted, but only Florence-2-large is loaded, so it
    does not change the model used.

    Returns a JSONResponse:
      - 200 with {"success": True, "caption", "image_size"} on success;
      - 400 with {"success": False, "error": <error dict>} if captioning fails;
      - 500 with {"success": False, "error": <message>} on unexpected errors.

    Raises:
        HTTPException(400): if ``image_url`` is missing or empty.
    """
    try:
        mc = model_choice or MODEL_CHOICE
        if not image_url:
            raise HTTPException(status_code=400, detail="image_url query parameter is required")
        # Downloading and model inference are blocking. Running them in a worker
        # thread keeps the event loop (and the other endpoints) responsive.
        # NOTE(review): requests can now run inference concurrently with each other,
        # as the background worker thread already could; add a lock if GPU memory
        # becomes an issue.
        result = await asyncio.to_thread(describe_image_from_url, image_url, mc)
        if isinstance(result, dict) and result.get("status") == "success":
            return JSONResponse(content={"success": True, "caption": result.get("caption"), "image_size": result.get("image_size")})
        return JSONResponse(status_code=400, content={"success": False, "error": result})
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "error": str(e)})
async def analyze_post(file: UploadFile = File(None), model_choice: str = Form(None)):
    """Caption an uploaded image (multipart/form-data field ``file``).

    ``model_choice`` is accepted but ignored, because only Florence-2-large is
    loaded.

    Returns:
        JSONResponse {"success": True, "caption": ...} on success. Unexpected
        errors become a 500 JSONResponse {"success": False, "error": ...}.

    Raises:
        HTTPException(400): if the file is missing or is not a decodable image.
        HTTPException(503): if the model is not loaded.
    """
    try:
        if file is None:
            raise HTTPException(status_code=400, detail="file is required")
        content = await file.read()
        try:
            pil_img = Image.open(BytesIO(content)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}") from e
        if vision_language_model is None:
            raise HTTPException(status_code=503, detail="Florence-2-large model not loaded")
        # Inference is blocking; run it off the event loop so other requests are served.
        caption = await asyncio.to_thread(
            process_image_description,
            vision_language_model,
            vision_language_processor,
            pil_img,
        )
        return JSONResponse(content={"success": True, "caption": caption})
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "error": str(e)})
# Listening port from the PORT environment variable (used on Hugging Face
# Spaces), defaulting to 7860.
port = int(os.getenv("PORT", "7860"))

# Serve the FastAPI app with uvicorn only when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=port, host="0.0.0.0")