|
|
import os
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoProcessor, AutoModelForCausalLM
|
|
|
|
|
|
|
|
# Inference device; the service is designed to run CPU-only.
DEVICE = "cpu"


# All images are resized to this fixed size before inference.
RESIZE_DIM = (512, 512)


# Reject image downloads larger than 10 MiB.
MAX_IMAGE_SIZE = 10 * 1024 * 1024


# Florence-2 task prompt; the service always requests the most detailed caption.
TASK = "<MORE_DETAILED_CAPTION>"
|
|
|
|
|
|
|
|
# FastAPI application object; endpoints are registered on it below.
app = FastAPI(


    title="Florence-2 Image Analysis API",


    description="Analyze images using Microsoft's Florence-2 model with detailed captions",


    version="1.0.0"


)
|
|
|
|
|
|
|
|
class ImageAnalysisRequest(BaseModel):
    """Request body for POST /analyze: URL of the image to caption."""

    # Validated by pydantic as a well-formed HTTP(S) URL.
    image_url: HttpUrl
|
|
|
|
|
class ImageAnalysisResponse(BaseModel):
    """Response body for /analyze.

    On success: ``caption`` holds the generated text, ``success`` is True,
    ``error_message`` is None. On failure: ``caption`` is empty, ``success``
    is False, ``error_message`` explains what went wrong.
    """

    caption: str
    success: bool
    # Bug fix: was `error_message: str = None` — None is not a valid default
    # for a plain `str` field (pydantic v2 rejects it at model definition).
    error_message: Optional[str] = None
|
|
|
|
|
|
|
|
# Load the Florence-2 model once at import time. On failure both globals are
# set to None; the endpoints check for this and report 503 instead of crashing.
print("[INFO] Loading Florence-2 model on CPU...")
try:
    MODEL_ID = "microsoft/Florence-2-large"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Bug fix: device_map="auto" requires the `accelerate` package and may
    # place weights on a GPU, which conflicts with this service's CPU-only
    # design (DEVICE = "cpu"; analyze_image moves inputs with .to(DEVICE)).
    # Load the weights explicitly onto DEVICE instead.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    ).to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None
|
|
|
|
|
|
|
|
def download_image(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL Image.

    Raises:
        ValueError: on network/HTTP errors, downloads over MAX_IMAGE_SIZE,
            non-image content types, or undecodable image data. Each failure
            mode keeps its own distinct message.
    """
    headers = {
        # Some image hosts refuse requests without a browser-like User-Agent.
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    # Bug fix: the original wrapped the whole body in one try whose trailing
    # `except Exception` caught the size/content-type ValueErrors raised just
    # above it and re-wrapped them as "Failed to process image: ...", masking
    # the real message. Each try below now covers only the lines that can
    # raise the error it reports.
    try:
        response = requests.get(str(url), headers=headers, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}")

    if len(response.content) > MAX_IMAGE_SIZE:
        raise ValueError(f"Image too large: {len(response.content)} bytes (max: {MAX_IMAGE_SIZE})")

    content_type = response.headers.get('content-type', '')
    if not content_type.startswith('image/'):
        raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")

    try:
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to process image: {e}")
|
|
|
|
|
def analyze_image(image: Image.Image) -> str:
    """Generate a detailed caption for *image* with the Florence-2 model.

    The image is resized to RESIZE_DIM, run through the processor with the
    fixed TASK prompt, and decoded with beam search (3 beams, greedy).

    Raises:
        ValueError: if the model failed to load or generation fails.
    """
    if not processor or not model:
        raise ValueError("Model not loaded properly")

    try:
        resized = image.resize(RESIZE_DIM, Image.BILINEAR)

        model_inputs = processor(
            text=TASK,
            images=resized,
            return_tensors="pt",
        ).to(DEVICE)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=model_inputs["input_ids"],
                pixel_values=model_inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )

        decoded = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

        # The model echoes the task prompt at the start of the output;
        # strip it so callers receive only the caption text.
        if decoded.startswith(TASK):
            decoded = decoded[len(TASK):].strip()

        return decoded
    except Exception as e:
        print(f"[ERROR] Exception in analyze_image: {e}")
        raise ValueError(f"Failed to analyze image: {e}")
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Health check endpoint."""
    model_ready = processor is not None and model is not None
    return {
        "message": "Florence-2 Image Analysis API",
        "status": "running",
        "model_loaded": model_ready,
        "task": TASK,
    }
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Detailed health check."""
    loaded = bool(processor and model)
    return {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": processor is not None and model is not None,
        "device": DEVICE,
        "task": TASK,
    }
|
|
|
|
|
@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model

    Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions

    Returns an ImageAnalysisResponse; download/analysis failures are reported
    with success=False rather than an HTTP error, except 503 when the model
    failed to load at startup.
    """
    if not processor or not model:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please check server logs."
        )

    try:
        print(f"[INFO] Processing image from: {request.image_url}")
        # Bug fix: download_image (blocking requests I/O) and analyze_image
        # (CPU-bound torch inference) were called directly inside this async
        # handler, freezing the event loop — and every other client — for the
        # duration of each request. Dispatch them to the threadpool instead.
        image = await run_in_threadpool(download_image, request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")

        caption = await run_in_threadpool(analyze_image, image)
        print(f"[INFO] Analysis complete")

        return ImageAnalysisResponse(
            caption=caption,
            success=True
        )
    except ValueError as e:
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e)
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=f"Internal server error: {str(e)}"
        )
|
|
|
|
|
@app.get("/analyze")
async def analyze_image_get(image_url: str):
    """
    GET endpoint for quick image analysis

    Usage: /analyze?image_url=https://example.com/image.jpg

    Invalid URLs produce a 400; other HTTP errors (e.g. 503 when the model
    is not loaded) propagate from the POST handler unchanged.
    """
    try:
        request = ImageAnalysisRequest(image_url=image_url)
    except Exception as e:
        # Only URL validation failures map to 400.
        raise HTTPException(status_code=400, detail=str(e))
    # Bug fix: the original blanket `except Exception` also caught the
    # HTTPException(503) raised by analyze_image_endpoint and re-wrapped it
    # as a 400; delegation now happens outside the try so status codes
    # survive intact.
    return await analyze_image_endpoint(request)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Port is configurable via the PORT env var; 7860 is the Hugging Face
    # Spaces convention.
    port = int(os.getenv("PORT", 7860))
    model_status = "Loaded" if (processor and model) else "Failed to load"

    print(f"[INFO] Starting server on port {port}")
    print(f"[INFO] Model status: {model_status}")
    print(f"[INFO] Task: {TASK}")
    print(f"[INFO] API Documentation: http://localhost:{port}/docs")

    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)