File size: 6,906 Bytes
a32396c
76de2e4
d059378
76de2e4
d059378
 
 
6b4d8dc
d059378
c6b9676
b5dfc9f
03901aa
 
d059378
047f73e
c6b9676
d059378
 
 
047f73e
d059378
 
c6b9676
d059378
 
 
99c4852
d059378
 
 
 
c6b9676
b5dfc9f
061f058
d059378
061f058
 
6b4d8dc
061f058
85a2cee
047f73e
 
 
d059378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6b9676
047f73e
 
d059378
 
 
 
 
 
061f058
047f73e
d059378
047f73e
d059378
6b4d8dc
d059378
061f058
d059378
 
 
 
 
6b4d8dc
d059378
6b4d8dc
d059378
061f058
047f73e
dbe00f2
047f73e
 
 
 
 
 
061f058
d059378
061f058
d059378
dbe00f2
d059378
 
 
 
 
 
 
047f73e
 
d059378
03901aa
d059378
 
 
 
 
 
 
047f73e
d059378
03901aa
d059378
 
 
 
047f73e
d059378
 
047f73e
 
d059378
047f73e
 
d059378
 
 
 
 
 
 
047f73e
 
 
d059378
 
 
 
 
 
047f73e
 
d059378
 
 
 
 
 
 
 
 
047f73e
 
 
 
 
d059378
 
047f73e
d059378
 
047f73e
d059378
047f73e
 
 
 
 
d059378
 
 
 
 
 
047f73e
d059378
 
 
047f73e
d059378
 
047f73e
d059378
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
from io import BytesIO
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, HttpUrl
from transformers import AutoProcessor, AutoModelForCausalLM

# ===== CONFIG =====
DEVICE = "cpu"                      # Use CPU for compatibility
RESIZE_DIM = (512, 512)             # Resize images to this resolution
MAX_IMAGE_SIZE = 10 * 1024 * 1024   # 10MB max image size
TASK = "<MORE_DETAILED_CAPTION>"    # Hardcoded task

# ===== FastAPI App =====
app = FastAPI(
    title="Florence-2 Image Analysis API",
    description="Analyze images using Microsoft's Florence-2 model with detailed captions",
    version="1.0.0"
)

# ===== Request/Response Models =====
class ImageAnalysisRequest(BaseModel):
    image_url: HttpUrl

class ImageAnalysisResponse(BaseModel):
    caption: str
    success: bool
    error_message: str = None

# ===== Load Florence-2 Base Model =====
print("[INFO] Loading Florence-2 model on CPU...")
try:
    MODEL_ID = "microsoft/Florence-2-large"
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map="auto"
    ).eval()
    print("[INFO] Model loaded successfully!")
except Exception as e:
    print(f"[ERROR] Failed to load model: {e}")
    processor = None
    model = None

# ===== Helper Functions =====
def download_image(url: str) -> Image.Image:
    """Download image from URL and return PIL Image"""
    try:
        # Set headers to mimic browser request
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        
        response = requests.get(str(url), headers=headers, timeout=30)
        response.raise_for_status()
        
        # Check content length
        if len(response.content) > MAX_IMAGE_SIZE:
            raise ValueError(f"Image too large: {len(response.content)} bytes (max: {MAX_IMAGE_SIZE})")
        
        # Check if content is actually an image
        content_type = response.headers.get('content-type', '')
        if not content_type.startswith('image/'):
            raise ValueError(f"URL does not point to an image. Content-Type: {content_type}")
        
        image = Image.open(BytesIO(response.content)).convert("RGB")
        return image
        
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download image: {e}")
    except Exception as e:
        raise ValueError(f"Failed to process image: {e}")

def analyze_image(image: Image.Image) -> str:
    """Analyze image using Florence-2 model with hardcoded task"""
    if not processor or not model:
        raise ValueError("Model not loaded properly")
    
    try:
        # Resize image for faster processing
        image = image.resize(RESIZE_DIM, Image.BILINEAR)

        # Prepare inputs with hardcoded task
        inputs = processor(
            text=TASK,
            images=image,
            return_tensors="pt"
        ).to(DEVICE)

        # Generate caption
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False
            )

        # Decode and clean output
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        # Remove the task prompt from the beginning if present
        if generated_text.startswith(TASK):
            generated_text = generated_text[len(TASK):].strip()
            
        return generated_text

    except Exception as e:
        print(f"[ERROR] Exception in analyze_image: {e}")
        raise ValueError(f"Failed to analyze image: {e}")
        
# ===== API Endpoints =====
@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "Florence-2 Image Analysis API",
        "status": "running",
        "model_loaded": processor is not None and model is not None,
        "task": TASK
    }

@app.get("/health")
async def health_check():
    """Detailed health check"""
    return {
        "status": "healthy" if (processor and model) else "unhealthy",
        "model_loaded": processor is not None and model is not None,
        "device": DEVICE,
        "task": TASK
    }

@app.post("/analyze", response_model=ImageAnalysisResponse)
async def analyze_image_endpoint(request: ImageAnalysisRequest):
    """
    Analyze an image from a URL using Florence-2 model
    Always uses <MORE_DETAILED_CAPTION> task for detailed image descriptions
    """
    try:
        # Validate model is loaded
        if not processor or not model:
            raise HTTPException(
                status_code=503,
                detail="Model not loaded. Please check server logs."
            )
        
        # Download and process image
        print(f"[INFO] Processing image from: {request.image_url}")
        image = download_image(request.image_url)
        print(f"[INFO] Image downloaded successfully: {image.size}")
        
        # Analyze image with hardcoded task
        caption = analyze_image(image)
        print(f"[INFO] Analysis complete")
        
        return ImageAnalysisResponse(
            caption=caption,
            success=True
        )
        
    except HTTPException:
        raise
    except ValueError as e:
        print(f"[ERROR] ValueError: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=str(e)
        )
    except Exception as e:
        print(f"[ERROR] Unexpected error: {e}")
        return ImageAnalysisResponse(
            caption="",
            success=False,
            error_message=f"Internal server error: {str(e)}"
        )

@app.get("/analyze")
async def analyze_image_get(image_url: str):
    """
    GET endpoint for quick image analysis
    Usage: /analyze?image_url=https://example.com/image.jpg
    """
    try:
        request = ImageAnalysisRequest(image_url=image_url)
        return await analyze_image_endpoint(request)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

# ===== Main Execution =====
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    print(f"[INFO] Starting server on port {port}")
    print(f"[INFO] Model status: {'Loaded' if (processor and model) else 'Failed to load'}")
    print(f"[INFO] Task: {TASK}")
    print(f"[INFO] API Documentation: http://localhost:{port}/docs")
    
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False
    )