import os
import time
from io import BytesIO
from typing import Dict

import torch
import uvicorn
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoProcessor

# Run fully offline: transformers will only read model files already present
# in the local cache and will never hit the network.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# Prefer the math and memory-efficient scaled-dot-product-attention kernels
# and disable the flash kernel. These flags only take effect on CUDA.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

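# Because TRANSFORMERS_OFFLINE=1 is set above, the Florence-2 checkpoint must
# already be in the local Hugging Face cache. One way to prefetch it before
# the first launch (assumes the huggingface_hub CLI is installed):
#
#   huggingface-cli download microsoft/Florence-2-large
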
MODEL_ID = "microsoft/Florence-2-large"
DEVICE = "cpu"  # float32 on CPU; set to "cuda" to use a GPU (the SDP flags above then apply)

app = FastAPI(title="Florence-2 Image Captioning API")

# Model and processor are populated once at startup by load_florence_model().
model = None
processor = None


def log_message(message: str):
    """Simple logging function"""
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}] {message}")


def load_florence_model():
    """Load Florence-2 model and processor"""
    global model, processor
    if model is None or processor is None:
        try:
            log_message("[*] Loading Florence-2 model and processor...")

            processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                local_files_only=True,
            )

            # Florence-2 ships its modeling code with the checkpoint and
            # registers it under AutoModelForCausalLM, so it must be loaded
            # through that auto class rather than AutoModelForVision2Seq.
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                local_files_only=True,
                torch_dtype=torch.float32,
            ).to(DEVICE)

            model.eval()
            log_message("[+] Florence-2 loaded and ready.")
        except Exception as e:
            log_message(f"[ERROR] Failed to load Florence-2 model: {e}")
            raise


def caption_image(image: Image.Image) -> str:
    """Generate a detailed caption for an image using Florence-2"""
    if model is None or processor is None:
        return "Model not loaded."

    # Florence-2 selects its task through a special prompt token.
    task_prompt = "<MORE_DETAILED_CAPTION>"

    try:
        inputs = processor(
            text=task_prompt,
            images=image,
            return_tensors="pt",
        ).to(DEVICE)

        with torch.no_grad():
            # Beam-search decoding, following the generation settings in the
            # Florence-2 model card example; max_new_tokens is capped at 1024
            # to stay within the text decoder's maximum sequence length.
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )

        # Decode with special tokens intact, then let the processor's
        # task-aware post-processing strip them; it returns a dict keyed
        # by the task prompt.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        return parsed.get(task_prompt, "").strip()

    except Exception as e:
        log_message(f"[!] Caption generation failed: {e}")
        return "Captioning error."

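# Quick interactive smoke test for caption_image (a sketch; assumes a local
# "sample.jpg" exists and the model weights are already cached):
#
#   load_florence_model()
#   print(caption_image(Image.open("sample.jpg").convert("RGB")))

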
@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    load_florence_model()


@app.post("/caption")
async def create_caption(file: UploadFile = File(...)) -> Dict:
    """API endpoint to receive an image and return its caption"""
    try:
        log_message(f"[API] Received image: {file.filename}")

        # Read the upload into memory and normalize to RGB for the processor.
        contents = await file.read()
        image = Image.open(BytesIO(contents)).convert("RGB")

        log_message(f"[API] Generating caption for {file.filename}")
        caption = caption_image(image)
        log_message(f"[API] Caption generated for {file.filename}: {caption[:100]}...")

        return {
            "status": "success",
            "filename": file.filename,
            "caption": caption,
        }

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        log_message(f"[ERROR] {error_msg}")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": error_msg},
        )


if __name__ == "__main__":
    log_message("Starting Florence-2 Vision Analysis API Server")
    uvicorn.run(app, host="0.0.0.0", port=8000)
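
# Example request against a running server (a sketch; assumes the `requests`
# package is installed and "photo.jpg" exists in the working directory):
#
#   import requests
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post("http://localhost:8000/caption", files={"file": f})
#   print(resp.json()["caption"])
#
# Or with curl:
#
#   curl -F "file=@photo.jpg" http://localhost:8000/caption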