# PIL2 / app.py
# Last commit: "Update app.py" (047f73e, verified) by Fred808 — file size 6.91 kB.
import os
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoModelForCausalLM, AutoProcessor
# ===== CONFIG =====
# Device that model inputs are moved to before generation.
DEVICE = "cpu"
# (width, height) every incoming image is resized to before preprocessing.
RESIZE_DIM = (512, 512)
# Upper bound on the downloaded image body, in bytes (10 MiB).
MAX_IMAGE_SIZE = 10 * 1024 ** 2
# Florence-2 task prompt used for every request; not configurable per call.
TASK = "<MORE_DETAILED_CAPTION>"
# ===== FastAPI App =====
# Metadata shown in the generated OpenAPI schema and the /docs page.
_APP_METADATA = {
    "title": "Florence-2 Image Analysis API",
    "description": "Analyze images using Microsoft's Florence-2 model with detailed captions",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# ===== Request/Response Models =====
class ImageAnalysisRequest(BaseModel):
    """Input for the /analyze endpoints: the image to caption."""

    # Validated by pydantic as an absolute URL.
    image_url: HttpUrl


class ImageAnalysisResponse(BaseModel):
    """Outcome of an image analysis request."""

    # Generated caption; an empty string when success is False.
    caption: str
    # True when the image was downloaded and captioned without error.
    success: bool
    # Human-readable failure reason; None on success.
    # Must be Optional: None is the default (and what every successful
    # response carries), and a plain `str` annotation makes pydantic v2
    # reject None when FastAPI validates the returned response model.
    error_message: Optional[str] = None
# ===== Load Florence-2 Model =====
# Runs once at import time. On any failure both `processor` and `model` are set
# to None; the endpoints check for that and report the model as unavailable
# instead of the server failing to start.
print("[INFO] Loading Florence-2 model on CPU...")
try:
    MODEL_ID = "microsoft/Florence-2-large"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Place the weights explicitly on DEVICE rather than using
    # device_map="auto": "auto" may pick a GPU (and requires `accelerate`),
    # whereas analyze_image() always moves its inputs to DEVICE, so the two
    # must agree to avoid a device-mismatch error at inference time.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    ).to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None
# ===== Helper Functions =====
def download_image(url: str) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    The response body is streamed, and the transfer is aborted as soon as it
    exceeds MAX_IMAGE_SIZE. A large or mislabelled resource is therefore
    never fully buffered in memory.

    Args:
        url: Absolute http(s) URL. A pydantic ``HttpUrl`` is also accepted;
            it is converted with ``str()``.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: on a network/HTTP error, a non-``image/*`` Content-Type,
            a body larger than MAX_IMAGE_SIZE, or data PIL cannot decode.
    """
    data = _fetch_image_bytes(str(url))
    try:
        return Image.open(BytesIO(data)).convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to process image: {e}") from e


def _fetch_image_bytes(url: str) -> bytes:
    """Return the body of *url*, enforcing an image Content-Type and the size cap."""
    # NOTE(review): *url* is caller-supplied and fetched server-side, so the
    # API can be used for SSRF (internal hosts, cloud metadata endpoints, ...).
    # Consider restricting hosts/addresses if this service is publicly reachable.
    # Browser-like User-Agent: some image hosts reject the default one.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    try:
        with requests.get(url, headers=headers, timeout=30, stream=True) as response:
            response.raise_for_status()
            # Validate the type from the headers before reading any of the body.
            content_type = response.headers.get('content-type', '')
            if not content_type.startswith('image/'):
                raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
            # Cheap early rejection when the server announces the size...
            declared = response.headers.get('content-length', '')
            if declared.isdigit() and int(declared) > MAX_IMAGE_SIZE:
                raise ValueError(f"Image too large: {declared} bytes (max: {MAX_IMAGE_SIZE})")
            # ...but Content-Length may be missing or wrong, so the cap is also
            # enforced on the bytes actually received.
            body = bytearray()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                body += chunk
                if len(body) > MAX_IMAGE_SIZE:
                    raise ValueError(f"Image too large: more than {MAX_IMAGE_SIZE} bytes (max: {MAX_IMAGE_SIZE})")
            return bytes(body)
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}") from e
    except ValueError:
        # Our own validation errors above already carry a precise message.
        raise
    except Exception as e:
        # Keep the "always raises ValueError" contract for unexpected errors.
        raise ValueError(f"Failed to download image: {e}") from e
def analyze_image(image: Image.Image) -> str:
    """Caption *image* with Florence-2 using the fixed TASK prompt.

    Args:
        image: Image to describe. It is resized to RESIZE_DIM before
            preprocessing.

    Returns:
        The generated caption, with surrounding whitespace removed.

    Raises:
        ValueError: if the model failed to load at startup, or if
            preprocessing, generation or decoding raises.
    """
    # The globals are set to None when loading fails, so compare with None
    # explicitly instead of relying on the truthiness of library objects.
    if processor is None or model is None:
        raise ValueError("Model not loaded properly")
    try:
        # NOTE(review): the Florence-2 processor resizes to its own input
        # resolution anyway (768x768 in the published preprocessor config —
        # confirm). If so, this pre-resize loses detail without making
        # inference faster.
        image = image.resize(RESIZE_DIM, Image.BILINEAR)
        # Preprocess the task prompt and image, then move tensors to DEVICE.
        inputs = processor(
            text=TASK,
            images=image,
            return_tensors="pt",
        ).to(DEVICE)
        # Deterministic beam search; no_grad skips autograd bookkeeping.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Drop an echoed task prompt if present; always trim whitespace.
        if generated_text.startswith(TASK):
            generated_text = generated_text[len(TASK):]
        return generated_text.strip()
    except Exception as e:
        print(f"[ERROR] Exception in analyze_image: {e}")
        raise ValueError(f"Failed to analyze image: {e}") from e
# ===== API Endpoints =====
@app.get("/")
async def root():
    """Health check endpoint"""
    # Static service identity plus whether both model globals were populated.
    loaded = processor is not None and model is not None
    return dict(
        message="Florence-2 Image Analysis API",
        status="running",
        model_loaded=loaded,
        task=TASK,
    )
@app.get("/health")
async def health_check():
    """Detailed health check"""
    # "status" is derived from the truthiness of both globals; "model_loaded"
    # from explicit None checks.
    if processor and model:
        status = "healthy"
    else:
        status = "unhealthy"
    loaded = processor is not None and model is not None
    return {
        "status": status,
        "model_loaded": loaded,
        "device": DEVICE,
        "task": TASK,
    }
@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model
    Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
    """
    # Contract: HTTP 503 when the model failed to load at startup. Every other
    # failure is reported in-band as success=False with an error_message.
    try:
        if processor is None or model is None:
            raise HTTPException(
                status_code=503,
                detail="Model not loaded. Please check server logs."
            )
        print(f"[INFO] Processing image from: {request.image_url}")
        # download_image() (blocking network I/O) and analyze_image() (CPU-bound
        # inference) are synchronous. Calling them directly in this async
        # handler would block the event loop, and with it every other request
        # including /health, for the whole duration. So run them in the worker
        # thread pool instead.
        image = await run_in_threadpool(download_image, request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")
        caption = await run_in_threadpool(analyze_image, image)
        print("[INFO] Analysis complete")
        return ImageAnalysisResponse(
            caption=caption,
            success=True,
        )
    except HTTPException:
        raise
    except ValueError as e:
        # Expected failures from download_image()/analyze_image().
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e),
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=f"Internal server error: {str(e)}",
        )
@app.get("/analyze")
async def analyze_image_get(image_url: str):
    """
    GET endpoint for quick image analysis
    Usage: /analyze?image_url=https://example.com/image.jpg
    """
    # Only an invalid URL maps to HTTP 400. The analysis is delegated to the
    # POST handler outside the try block. Its HTTPException (503 when the model
    # is not loaded) then reaches the client unchanged instead of being
    # rewritten as a 400.
    try:
        request = ImageAnalysisRequest(image_url=image_url)
    except ValueError as e:  # pydantic's ValidationError subclasses ValueError
        raise HTTPException(status_code=400, detail=str(e)) from e
    return await analyze_image_endpoint(request)
# ===== Main Execution =====
if __name__ == "__main__":
    # Listening port comes from the PORT environment variable (default 7860).
    port = int(os.getenv("PORT", 7860))
    model_status = "Loaded" if (processor and model) else "Failed to load"
    startup_lines = (
        f"Starting server on port {port}",
        f"Model status: {model_status}",
        f"Task: {TASK}",
        f"API Documentation: http://localhost:{port}/docs",
    )
    for line in startup_lines:
        print(f"[INFO] {line}")
    # Bind on all interfaces; auto-reload stays off.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)