import os
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoProcessor, AutoModelForCausalLM
# ===== CONFIG =====
DEVICE = "cpu"  # Use CPU for compatibility
RESIZE_DIM = (512, 512)  # Inference-time resize target (width, height)
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB max download size
TASK = "<MORE_DETAILED_CAPTION>"  # Hardcoded Florence-2 task prompt

# ===== FastAPI App =====
app = FastAPI(
    title="Florence-2 Image Analysis API",
    description="Analyze images using Microsoft's Florence-2 model with detailed captions",
    version="1.0.0"
)
# ===== Request/Response Models =====
class ImageAnalysisRequest(BaseModel):
    """Request body for POST /analyze."""

    # URL of the image to caption; pydantic validates it is a well-formed HTTP(S) URL.
    image_url: HttpUrl
class ImageAnalysisResponse(BaseModel):
    """Response body for /analyze endpoints."""

    # Generated caption; empty string when success is False.
    caption: str
    success: bool
    # Was `str = None`, which mistypes the default (pydantic v2 rejects a
    # None default on a plain `str` field). Optional[str] is the correct type.
    error_message: Optional[str] = None
# ===== Load Florence-2 Base Model =====
# Loaded once at import time; on failure both globals stay None and the
# endpoints report "model not loaded" instead of crashing the server.
print("[INFO] Loading Florence-2 model on CPU...")
try:
    MODEL_ID = "microsoft/Florence-2-large"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Place the model on DEVICE explicitly. The previous device_map="auto"
    # could put weights on an accelerator while analyze_image() moves inputs
    # to DEVICE ("cpu"), causing a device mismatch at generate() time.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    ).to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None
# ===== Helper Functions =====
def download_image(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL Image.

    Raises:
        ValueError: on network failure, oversized payload, non-image
            content type, or undecodable image data.
    """
    try:
        # Mimic a browser so hosts that block generic clients still respond.
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        # stream=True defers the body download, so oversized files can be
        # rejected from the headers before any bytes are read (the previous
        # version downloaded the full body before checking the size cap).
        response = requests.get(str(url), headers=headers, timeout=30, stream=True)
        response.raise_for_status()
        declared = response.headers.get('content-length')
        if declared and int(declared) > MAX_IMAGE_SIZE:
            raise ValueError(f"Image too large: {declared} bytes (max: {MAX_IMAGE_SIZE})")
        # Check if content is actually an image
        content_type = response.headers.get('content-type', '')
        if not content_type.startswith('image/'):
            raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
        # Read in chunks, enforcing the cap even when Content-Length is absent or lies.
        data = bytearray()
        for chunk in response.iter_content(chunk_size=65536):
            data.extend(chunk)
            if len(data) > MAX_IMAGE_SIZE:
                raise ValueError(f"Image too large: {len(data)} bytes (max: {MAX_IMAGE_SIZE})")
        image = Image.open(BytesIO(bytes(data))).convert("RGB")
        return image
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}")
    except Exception as e:
        # Mirrors the original behavior: ValueErrors raised above are
        # re-wrapped with a "Failed to process image" prefix.
        raise ValueError(f"Failed to process image: {e}")
def analyze_image(image: Image.Image) -> str:
    """Caption *image* with the Florence-2 model using the fixed TASK prompt.

    Returns:
        The generated caption text with the task prompt stripped.

    Raises:
        ValueError: if the model is not loaded or inference fails.
    """
    if not processor or not model:
        raise ValueError("Model not loaded properly")
    try:
        # Downscale for faster CPU inference; Florence-2 tolerates the fixed size.
        image = image.resize(RESIZE_DIM, Image.BILINEAR)
        # Build the multimodal input batch and move tensors to the model device.
        inputs = processor(
            text=TASK,
            images=image,
            return_tensors="pt"
        ).to(DEVICE)
        # Beam-search decoding, deterministic (do_sample=False).
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False
            )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # The model may echo the task prompt at the start of the output; strip it.
        if generated_text.startswith(TASK):
            generated_text = generated_text[len(TASK):].strip()
        return generated_text
    except Exception as e:
        print(f"[ERROR] Exception in analyze_image: {e}")
        raise ValueError(f"Failed to analyze image: {e}")
# ===== API Endpoints =====
@app.get("/")
async def root():
    """Health check endpoint: reports service status and model availability."""
    return {
        "message": "Florence-2 Image Analysis API",
        "status": "running",
        "model_loaded": processor is not None and model is not None,
        "task": TASK
    }
@app.get("/health")
async def health_check():
    """Detailed health check: model load state, device, and configured task."""
    return {
        "status": "healthy" if (processor and model) else "unhealthy",
        "model_loaded": processor is not None and model is not None,
        "device": DEVICE,
        "task": TASK
    }
@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model
    Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions

    Returns a 503 when the model failed to load; expected failures
    (bad URL, bad image) come back as success=False with a 200 status.
    """
    try:
        # Validate model is loaded before doing any work.
        if not processor or not model:
            raise HTTPException(
                status_code=503,
                detail="Model not loaded. Please check server logs."
            )
        print(f"[INFO] Processing image from: {request.image_url}")
        image = download_image(request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")
        caption = analyze_image(image)
        print(f"[INFO] Analysis complete")
        return ImageAnalysisResponse(
            caption=caption,
            success=True
        )
    except HTTPException:
        # Let FastAPI render the intended status code (e.g. the 503 above).
        raise
    except ValueError as e:
        # Expected failures from download_image / analyze_image.
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e)
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=f"Internal server error: {str(e)}"
        )
@app.get("/analyze")
async def analyze_image_get(image_url: str):
    """
    GET endpoint for quick image analysis
    Usage: /analyze?image_url=https://example.com/image.jpg
    """
    try:
        request = ImageAnalysisRequest(image_url=image_url)
        return await analyze_image_endpoint(request)
    except HTTPException:
        # Bug fix: the bare `except Exception` below used to swallow
        # HTTPExceptions from the POST handler (e.g. the 503 "model not
        # loaded") and rewrite them as 400s. Re-raise them untouched.
        raise
    except Exception as e:
        # Anything else here is a bad query parameter (URL validation failure).
        raise HTTPException(status_code=400, detail=str(e))
# ===== Main Execution =====
if __name__ == "__main__":
    # Port is configurable via env (7860 is the Hugging Face Spaces default).
    port = int(os.getenv("PORT", 7860))
    print(f"[INFO] Starting server on port {port}")
    print(f"[INFO] Model status: {'Loaded' if (processor and model) else 'Failed to load'}")
    print(f"[INFO] Task: {TASK}")
    print(f"[INFO] API Documentation: http://localhost:{port}/docs")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False
    )