"""Florence-2 image analysis API.

Exposes a small FastAPI service that downloads an image from a URL, runs it
through Microsoft's Florence-2 model with a fixed task prompt, and returns the
generated caption.
"""

import os
import threading
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoModelForCausalLM, AutoProcessor

# ===== CONFIG =====
DEVICE = "cpu"  # Use CPU for compatibility
RESIZE_DIM = (512, 512)  # Resize images to this resolution
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB max image size
MODEL_ID = "microsoft/Florence-2-large"
DOWNLOAD_CHUNK_SIZE = 64 * 1024  # Bytes read per iteration while streaming

# Task prompt sent to the model with every image.
# NOTE(review): this literal used to be an empty string, which looks like an
# angle-bracket task token lost to markup stripping. "<MORE_DETAILED_CAPTION>"
# is assumed from the API description ("detailed captions") - confirm.
# Set FLORENCE_TASK (possibly to an empty string) to override.
TASK = os.getenv("FLORENCE_TASK", "<MORE_DETAILED_CAPTION>")

# Keeps model.generate() calls one-at-a-time now that inference runs in
# worker threads instead of blocking the event loop.
_INFERENCE_LOCK = threading.Lock()

# ===== FastAPI App =====
app = FastAPI(
    title="Florence-2 Image Analysis API",
    description="Analyze images using Microsoft's Florence-2 model with detailed captions",
    version="1.0.0",
)


# ===== Request/Response Models =====
class ImageAnalysisRequest(BaseModel):
    """Request body for image analysis: the URL of the image to caption."""

    image_url: HttpUrl


class ImageAnalysisResponse(BaseModel):
    """Result of an analysis; ``error_message`` is set only on failure."""

    caption: str
    success: bool
    error_message: Optional[str] = None


# ===== Load Florence-2 Model =====
print("[INFO] Loading Florence-2 model on CPU...")
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Place the model explicitly on DEVICE so it always matches the device the
    # inputs are moved to in analyze_image().
    model = (
        AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        .to(DEVICE)
        .eval()
    )
    print("[INFO] Model loaded successfully!")
except Exception as e:
    # Keep the app importable so the health endpoints can report the failure.
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None


# ===== Helper Functions =====
def _model_ready() -> bool:
    """Return True when both the processor and the model were loaded."""
    return processor is not None and model is not None


def download_image(url: str) -> Image.Image:
    """Download an image from ``url`` and return it as an RGB PIL image.

    The body is streamed so that a response larger than ``MAX_IMAGE_SIZE`` is
    rejected without being held fully in memory.

    Raises:
        ValueError: if the request fails, the response is not an image, the
            payload exceeds ``MAX_IMAGE_SIZE``, or the data cannot be decoded.
    """
    # NOTE(security): the URL comes from the API caller and is fetched
    # server-side as-is; restrict reachable hosts at the network layer if this
    # service can see internal addresses.
    # Set headers to mimic browser request
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    buffer = BytesIO()
    try:
        with requests.get(str(url), headers=headers, timeout=30, stream=True) as response:
            response.raise_for_status()

            # Check the headers before reading any of the body
            content_type = response.headers.get('content-type', '')
            if not content_type.lower().startswith('image/'):
                raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")

            declared_size = response.headers.get('content-length', '')
            if declared_size.isdigit() and int(declared_size) > MAX_IMAGE_SIZE:
                raise ValueError(f"Image too large: {declared_size} bytes (max: {MAX_IMAGE_SIZE})")

            # Content-Length may be absent or wrong, so enforce the cap while reading
            for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                buffer.write(chunk)
                if buffer.tell() > MAX_IMAGE_SIZE:
                    raise ValueError(f"Image too large: more than {MAX_IMAGE_SIZE} bytes")
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}") from e

    buffer.seek(0)
    try:
        return Image.open(buffer).convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to process image: {e}") from e


def analyze_image(image: Image.Image) -> str:
    """Generate a caption for ``image`` using Florence-2 and the ``TASK`` prompt.

    Raises:
        ValueError: if the model is not loaded or inference fails.
    """
    if processor is None or model is None:
        raise ValueError("Model not loaded properly")

    try:
        # Resize image for faster processing
        image = image.resize(RESIZE_DIM, Image.BILINEAR)

        # Prepare inputs with hardcoded task
        inputs = processor(
            text=TASK,
            images=image,
            return_tensors="pt",
        ).to(DEVICE)

        # Generate caption (serialized; see _INFERENCE_LOCK)
        with _INFERENCE_LOCK, torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )

        # Decode and clean output
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Remove the task prompt from the beginning if present
        if generated_text.startswith(TASK):
            generated_text = generated_text[len(TASK):].strip()

        return generated_text
    except Exception as e:
        print(f"[ERROR] Exception in analyze_image: {e}")
        raise ValueError(f"Failed to analyze image: {e}") from e


# ===== API Endpoints =====
@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "Florence-2 Image Analysis API",
        "status": "running",
        "model_loaded": _model_ready(),
        "task": TASK,
    }


@app.get("/health")
async def health_check():
    """Detailed health check"""
    ready = _model_ready()
    return {
        "status": "healthy" if ready else "unhealthy",
        "model_loaded": ready,
        "device": DEVICE,
        "task": TASK,
    }


@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model.

    Always uses the configured ``TASK`` prompt. Responds with HTTP 503 when the
    model is not loaded; download and inference failures are reported in the
    response body with ``success=False``.
    """
    # Validate model is loaded
    if not _model_ready():
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please check server logs.",
        )

    try:
        # Download and process image; both steps block, so run them in worker
        # threads to keep the event loop (and the health endpoints) responsive.
        print(f"[INFO] Processing image from: {request.image_url}")
        image = await run_in_threadpool(download_image, request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")

        # Analyze image with hardcoded task
        caption = await run_in_threadpool(analyze_image, image)
        print("[INFO] Analysis complete")

        return ImageAnalysisResponse(
            caption=caption,
            success=True,
        )
    except ValueError as e:
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e),
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=f"Internal server error: {str(e)}",
        )


@app.get("/analyze")
async def analyze_image_get(image_url: str):
    """
    GET endpoint for quick image analysis
    Usage: /analyze?image_url=https://example.com/image.jpg
    """
    # Only an invalid URL is a client error (400); anything raised by the POST
    # handler, such as the 503 for an unloaded model, must propagate unchanged.
    try:
        request = ImageAnalysisRequest(image_url=image_url)
    except ValueError as e:  # pydantic's ValidationError is a ValueError
        raise HTTPException(status_code=400, detail=str(e)) from e

    return await analyze_image_endpoint(request)


# ===== Main Execution =====
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))

    print(f"[INFO] Starting server on port {port}")
    print(f"[INFO] Model status: {'Loaded' if _model_ready() else 'Failed to load'}")
    print(f"[INFO] Task: {TASK}")
    print(f"[INFO] API Documentation: http://localhost:{port}/docs")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False,
    )