| import os
|
| import sys
|
| import time
|
| import subprocess
|
| import numpy as np
|
| from PIL import Image
|
| from io import BytesIO
|
| import requests
|
| import threading
|
|
|
|
|
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
| from fastapi.responses import JSONResponse
|
| import uvicorn
|
|
|
|
|
def setup_environment(dependencies=None):
    """Make sure the runtime Python packages are importable, installing any that are missing.

    Each package is probed individually, so only the missing ones are handed to pip
    (previously all five were reinstalled whenever any of three probed modules failed
    to import, and pillow/numpy were never probed at all).

    Args:
        dependencies: Optional mapping of pip distribution name -> importable module
            name. Defaults to the packages this server needs.

    Returns:
        The list of pip package names that were installed (empty when nothing
        was missing).

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    import importlib

    print("--- Setting up environment ---")

    if dependencies is None:
        # pip distribution name -> module name used to probe it (they differ for pillow).
        dependencies = {
            "huggingface_hub": "huggingface_hub",
            "onnxruntime": "onnxruntime",
            "transformers": "transformers",
            "pillow": "PIL",
            "numpy": "numpy",
        }

    missing = []
    for package, module in dependencies.items():
        try:
            importlib.import_module(module)
        except ImportError:
            missing.append(package)

    if not missing:
        print("Dependencies already satisfied.")
        return missing

    print("Installing dependencies...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # Packages installed while this process is running are only discoverable by
    # later imports once the import system's finder caches are refreshed.
    importlib.invalidate_caches()
    return missing
|
|
|
|
|
def download_model(repo_id="Heliosoph/florence-2-base-ft-quantized-onnx", local_dir="florence2_quantized"):
    """Fetch a Hugging Face model snapshot into ``local_dir`` unless that path already exists.

    Args:
        repo_id: Hugging Face repository to download from.
        local_dir: Destination directory for the snapshot.

    NOTE(review): only the existence of ``local_dir`` is checked, so an interrupted
    earlier download that left the directory behind will not be resumed.
    """
    from huggingface_hub import snapshot_download

    if os.path.exists(local_dir):
        print(f"Model directory '{local_dir}' already exists.")
        return

    print(f"--- Downloading model from {repo_id} ---")
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print("Download complete.")
|
|
|
|
|
class Florence2ONNXEngine:
    """Florence-2 image captioner driven by four quantized ONNX Runtime sessions.

    The model is split into a vision encoder, a token-embedding model, a text
    encoder and a decoder. ``generate_caption`` chains them and decodes greedily,
    one token per step.
    """

    # Florence-2 task tokens mapped to the natural-language prompt fed to the text
    # encoder; any other task string is passed through verbatim.
    PROMPT_MAP = {
        "<CAPTION>": "What does the image describe?",
        "<DETAILED_CAPTION>": "Describe this image in detail.",
        "<MORE_DETAILED_CAPTION>": "Describe this image in great detail with every object and background.",
    }

    # Token id that seeds decoding and also terminates it when emitted.
    # NOTE(review): 2 is "</s>" in the facebook/bart-base vocabulary loaded below;
    # assumed to be Florence-2's decoder start / EOS id — confirm.
    EOS_TOKEN_ID = 2

    def __init__(self, model_dir="florence2_quantized", providers=None):
        """Load the pre/post-processors and open the four ONNX sessions.

        Args:
            model_dir: Directory containing the ``*_quantized.onnx`` files.
            providers: ONNX Runtime execution providers; defaults to CPU only.
        """
        import onnxruntime as ort
        from transformers import CLIPImageProcessor, AutoTokenizer

        self.model_dir = model_dir
        print("--- Initializing ONNX Engine ---")

        self.image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-base-ft")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

        if providers is None:
            providers = ['CPUExecutionProvider']

        def _session(filename):
            # One InferenceSession per exported sub-graph, all on the same providers.
            return ort.InferenceSession(os.path.join(model_dir, filename), providers=providers)

        self.vision_session = _session('vision_encoder_quantized.onnx')
        self.embed_session = _session('embed_tokens_quantized.onnx')
        self.encoder_session = _session('encoder_model_quantized.onnx')
        self.decoder_session = _session('decoder_model_quantized.onnx')
        print("✓ Florence-2 ONNX Engine initialized successfully")

    def generate_caption(self, image_path=None, image_array=None, task_prompt="<MORE_DETAILED_CAPTION>",
                         max_new_tokens=1024, min_new_tokens=250, repetition_penalty=1.5,
                         repetition_window=50):
        """Generate a caption for an image given as a path, PIL image or NumPy array.

        Args:
            image_path: Path to an image file; takes precedence when truthy.
            image_array: A ``PIL.Image.Image`` or a NumPy array accepted by
                ``Image.fromarray``; used when ``image_path`` is falsy.
            task_prompt: A key of ``PROMPT_MAP`` or a literal prompt string.
            max_new_tokens: Upper bound on decoding steps.
            min_new_tokens: EOS is suppressed during the first this-many steps.
            repetition_penalty: Divides positive / multiplies non-positive logits of
                recently generated ids.
            repetition_window: How many of the most recent ids the penalty covers.

        Returns:
            The decoded caption with special tokens removed.

        Raises:
            ValueError: if no usable image input is supplied.
        """
        image = self._load_image(image_path, image_array)

        print(f"--- Running Inference (Max Tokens: {max_new_tokens}) ---")
        pixel_values = self.image_processor(images=image, return_tensors="np")['pixel_values']

        text_prompt = self.PROMPT_MAP.get(task_prompt, task_prompt)
        input_ids = self.tokenizer(text_prompt, return_tensors="np")['input_ids']

        start_time = time.time()
        last_hidden_state, attention_mask = self._encode(pixel_values, input_ids)
        generated_ids = self._greedy_decode(
            last_hidden_state,
            attention_mask,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
        )
        end_time = time.time()

        caption = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"Inference complete in {end_time - start_time:.2f}s")
        return caption

    @staticmethod
    def _load_image(image_path, image_array):
        """Return an RGB PIL image from a file path, a PIL image or a NumPy array."""
        if image_path:
            # The context manager closes the file handle; convert() has already
            # loaded the pixels into a new, independent image.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        if isinstance(image_array, Image.Image):
            return image_array.convert("RGB")
        if isinstance(image_array, np.ndarray):
            return Image.fromarray(image_array).convert("RGB")
        if image_array is not None:
            raise ValueError(f"Unsupported image_array type: {type(image_array).__name__}")
        raise ValueError("Either image_path or image_array must be provided")

    def _encode(self, pixel_values, input_ids):
        """Run the vision and text encoders; return (encoder hidden states, attention mask)."""
        image_features = self.vision_session.run(None, {'pixel_values': pixel_values})[0]
        text_embeds = self.embed_session.run(None, {'input_ids': input_ids})[0]

        # Image features first, then prompt embeddings, joined on axis 1 (the axis
        # whose length sizes the all-ones attention mask).
        combined_embeds = np.concatenate([image_features, text_embeds], axis=1)
        attention_mask = np.ones((1, combined_embeds.shape[1]), dtype=np.int64)
        encoder_outputs = self.encoder_session.run(None, {
            'inputs_embeds': combined_embeds,
            'attention_mask': attention_mask,
        })
        return encoder_outputs[0], attention_mask

    def _greedy_decode(self, last_hidden_state, attention_mask, max_new_tokens,
                       min_new_tokens, repetition_penalty, repetition_window):
        """Greedily produce token ids (seeded with EOS_TOKEN_ID) until EOS or max_new_tokens."""
        eos = self.EOS_TOKEN_ID
        generated_ids = [eos]

        for i in range(max_new_tokens):
            # No past-key-value inputs are passed, so the whole prefix is re-embedded
            # and re-decoded every step (cost grows quadratically with length).
            decoder_input_ids = np.array([generated_ids], dtype=np.int64)
            decoder_embeds = self.embed_session.run(None, {'input_ids': decoder_input_ids})[0]
            logits = self.decoder_session.run(None, {
                'inputs_embeds': decoder_embeds,
                'encoder_hidden_states': last_hidden_state,
                'encoder_attention_mask': attention_mask,
            })[0]

            current_logits = logits[0, -1, :].copy()
            recent = generated_ids[-repetition_window:] if repetition_window > 0 else []
            self._apply_repetition_penalty(current_logits, recent, repetition_penalty)

            next_token = int(np.argmax(current_logits))
            if next_token == eos and i < min_new_tokens:
                # Too early to stop: take the best non-EOS token instead.
                current_logits[eos] = -1e9
                next_token = int(np.argmax(current_logits))
            if next_token == eos:
                break
            generated_ids.append(next_token)

            if (i + 1) % 50 == 0:
                print(f"Generated {i+1} tokens...")

        return generated_ids

    @staticmethod
    def _apply_repetition_penalty(logits, recent_ids, penalty):
        """In place: divide positive and multiply non-positive logits of each distinct recent id."""
        if not recent_ids:
            return
        idx = np.fromiter(set(recent_ids), dtype=np.int64)
        values = logits[idx]
        logits[idx] = np.where(values > 0, values / penalty, values * penalty)
|
|
|
|
|
|
|
# Process-wide captioning engine. It stays None until initialize_engine() assigns it;
# the /health and /analyze handlers check for None to report "initializing" / 503.
engine: "Florence2ONNXEngine | None" = None
|
|
|
def initialize_engine():
    """Install dependencies, fetch the model files, then build the shared engine.

    Side effect: rebinds the module-level ``engine`` global to a new
    ``Florence2ONNXEngine`` built from the default model directory.
    """
    global engine
    # Preparation steps run in order; each is responsible for its own logging.
    for prepare in (setup_environment, download_model):
        prepare()
    engine = Florence2ONNXEngine()
|
|
|
|
|
# ASGI application object; the route decorators below attach to it and the
# __main__ guard serves it with uvicorn.
app = FastAPI(
    description="Auto-captions images using Florence-2 ONNX models",
    title="Florence-2 ONNX Image Captioning Server",
)
|
|
|
def load_image_from_url(image_url: str, timeout: float = 30, max_bytes: int = None) -> Image.Image:
    """Download an image and return it as an RGB PIL image.

    Args:
        image_url: URL to fetch. When it comes from a request it is untrusted input.
        timeout: ``requests`` timeout in seconds (default matches the former hard-coded 30).
        max_bytes: If set, abort as soon as the response body exceeds this many bytes,
            so a hostile URL cannot make the server buffer an arbitrarily large body.
            ``None`` keeps the previous unlimited behavior.

    Returns:
        The decoded image converted to RGB.

    Raises:
        ValueError: on any network, HTTP-status, size or decoding failure. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        # stream=True lets the size cap stop reading early; the with-block
        # releases the connection whichever way we leave it.
        with requests.get(image_url, timeout=timeout, stream=True) as response:
            response.raise_for_status()
            if max_bytes is None:
                payload = response.content
            else:
                chunks = []
                received = 0
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    received += len(chunk)
                    if received > max_bytes:
                        raise ValueError(f"response body exceeds {max_bytes} bytes")
                    chunks.append(chunk)
                payload = b"".join(chunks)
        image = Image.open(BytesIO(payload))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}") from e
|
|
|
def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Decode an encoded image held in memory into an RGB PIL image.

    Raises:
        ValueError: wrapping whatever failed while wrapping, opening or converting
            the bytes (e.g. an unrecognised image format).
    """
    try:
        decoded = Image.open(BytesIO(image_bytes)).convert('RGB')
    except Exception as err:
        raise ValueError(f"Error loading image from bytes: {err}")
    return decoded
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the server: its name, model, load state and available routes."""
    routes = {
        "GET /health": "Health check",
        "GET /analyze": "Analyze image from URL",
        "POST /analyze": "Analyze uploaded image",
    }
    loaded = engine is not None
    return {
        "name": "Florence-2 ONNX Image Captioning Server",
        "status": "running",
        "model": "Florence-2-base-ft-quantized-onnx",
        "model_loaded": loaded,
        "endpoints": routes,
    }
|
|
|
@app.get("/health")
async def health():
    """Readiness probe: "healthy" once the engine exists, "initializing" until then."""
    loaded = engine is not None
    status = "initializing"
    if loaded:
        status = "healthy"
    return {
        "status": status,
        "model": "Florence-2-base-ft-quantized-onnx",
        "model_loaded": loaded,
    }
|
|
|
@app.get("/analyze")
async def analyze_get(image_url: str = None):
    """Analyze an image by URL.

    Usage: /analyze?image_url=https://example.com/image.jpg

    Status codes: 503 while the model is still loading; 400 when ``image_url``
    is missing or the image cannot be fetched/decoded; 500 (JSON body with
    "error") when inference itself fails.
    """
    # FastAPI's helper for running a blocking callable in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")

        if not image_url:
            raise HTTPException(status_code=400, detail="image_url query parameter is required")

        # The HTTP download and CPU-bound inference both block. Running them off the
        # event loop keeps /health and other requests responsive during a long caption.
        try:
            image = await run_in_threadpool(load_image_from_url, image_url)
        except ValueError as e:
            # A bad or unreachable URL is the client's problem, mirroring the
            # 400 the POST route returns for an unreadable upload.
            raise HTTPException(status_code=400, detail=str(e)) from e

        caption = await run_in_threadpool(engine.generate_caption, image_array=image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
            "model": "Florence-2-base-ft-quantized-onnx"
        })

    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
|
|
|
@app.post("/analyze")
async def analyze_post(file: UploadFile = File(None)):
    """Analyze an uploaded image (multipart/form-data).

    Returns: JSON with caption and metadata

    Status codes: 503 while the model is still loading; 400 when no file, an
    empty file or an undecodable image is sent; 500 (JSON body with "error")
    when inference itself fails.
    """
    # FastAPI's helper for running a blocking callable in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")

        if file is None:
            raise HTTPException(status_code=400, detail="file is required")

        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # Decoding and (above all) inference are blocking CPU work. Running them in
        # the threadpool keeps the event loop free to serve other requests.
        try:
            image = await run_in_threadpool(load_image_from_bytes, content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}") from e

        caption = await run_in_threadpool(engine.generate_caption, image_array=image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "filename": file.filename,
            "image_size": {"width": image.width, "height": image.height},
            "model": "Florence-2-base-ft-quantized-onnx"
        })

    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
|
|
|
|
|
# TCP port to listen on: the PORT environment variable when set, else 7860.
port = int(os.getenv("PORT", "7860"))
|
|
|
|
|
if __name__ == "__main__":
    # Build the engine before binding the socket so early requests do not
    # land on the "Model not initialized" (503) path.
    print("Initializing Florence-2 ONNX Engine...")
    initialize_engine()

    banner = (
        f"\n✓ Server ready! Starting on 0.0.0.0:{port}",
        f"API Documentation: http://localhost:{port}/docs",
    )
    for line in banner:
        print(line)

    uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|