import importlib.util
import os
import subprocess
import sys
import threading
import time
from io import BytesIO

import numpy as np
import requests
import uvicorn
from PIL import Image

# FastAPI imports
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse


# 1. Environment Setup & Dependency Installation

# Importable module name -> pip distribution that provides it.
_REQUIRED_PACKAGES = {
    "huggingface_hub": "huggingface_hub",
    "onnxruntime": "onnxruntime",
    "transformers": "transformers",
    "PIL": "pillow",
    "numpy": "numpy",
}


def setup_environment():
    """Install, via pip, any runtime dependency whose module cannot be found.

    Presence is probed with ``importlib.util.find_spec`` so heavy packages
    (e.g. transformers) are not imported just to check that they exist, and
    only the missing distributions are passed to pip.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
    """
    print("--- Setting up environment ---")
    missing = [
        package
        for module, package in _REQUIRED_PACKAGES.items()
        if importlib.util.find_spec(module) is None
    ]
    if not missing:
        print("Dependencies already satisfied.")
        return

    print("Installing dependencies...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # The import system caches directory listings; without this, packages that
    # pip just installed into this running interpreter may not be importable.
    importlib.invalidate_caches()


# 2. Model Download
def download_model(repo_id="Heliosoph/florence-2-base-ft-quantized-onnx",
                   local_dir="florence2_quantized", required_files=None):
    """Download a model snapshot from the Hugging Face Hub into *local_dir*.

    The download is skipped only when *local_dir* is a non-empty directory
    that contains every file listed in *required_files* (if given). An empty
    directory is no longer mistaken for a complete model, and passing
    *required_files* lets callers detect a partially populated directory.

    Args:
        repo_id: Hub repository to download.
        local_dir: Destination directory.
        required_files: Optional iterable of file names, relative to
            *local_dir*, that must all exist for the local copy to be used.
    """
    from huggingface_hub import snapshot_download

    if os.path.isdir(local_dir) and os.listdir(local_dir):
        missing = [
            name
            for name in (required_files or ())
            if not os.path.isfile(os.path.join(local_dir, name))
        ]
        if not missing:
            print(f"Model directory '{local_dir}' already exists.")
            return
        print(f"Model directory '{local_dir}' is missing {', '.join(missing)}; re-downloading.")

    print(f"--- Downloading model from {repo_id} ---")
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print("Download complete.")
# 3. Inference Engine
class Florence2ONNXEngine:
    """Florence-2 image captioning backed by four quantized ONNX Runtime sessions.

    Pipeline: vision encoder (pixels -> image features), token embedder
    (token ids -> embeddings), encoder (image features + prompt embeddings ->
    hidden states) and decoder (decoder embeddings + encoder states -> logits).
    Decoding is greedy (argmax) with a repetition penalty. The decoder is re-run
    on the whole generated prefix at every step (no past-key-value cache is fed
    to it), so generation cost grows quadratically with output length.
    """

    # Token id 2 is used both as the decoder start token and as the
    # end-of-sequence marker that stops generation (BART start/EOS token).
    EOS_TOKEN_ID = 2

    # How many of the most recent tokens the repetition penalty considers.
    REPETITION_WINDOW = 50

    # Task token -> natural-language prompt fed to the encoder; any other
    # task_prompt is used verbatim. An empty task_prompt (the default) selects
    # the most detailed prompt.
    # NOTE(review): the task-token keys are the Florence-2 caption task tokens;
    # confirm they match what callers actually send.
    PROMPT_MAP = {
        "": "Describe this image in great detail with every object and background.",
        "<CAPTION>": "What does the image describe?",
        "<DETAILED_CAPTION>": "Describe this image in detail.",
        "<MORE_DETAILED_CAPTION>": (
            "Describe this image in great detail with every object and background."
        ),
    }

    def __init__(self, model_dir="florence2_quantized"):
        """Load the image processor, tokenizer and ONNX sessions.

        The processor and tokenizer are loaded with ``from_pretrained`` from
        ``microsoft/Florence-2-base-ft`` and ``facebook/bart-base``; the four
        ``*_quantized.onnx`` files must sit directly inside *model_dir*.
        """
        import onnxruntime as ort
        from transformers import AutoTokenizer, CLIPImageProcessor

        self.model_dir = model_dir
        print("--- Initializing ONNX Engine ---")

        # Load processors
        self.image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-base-ft")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

        # Load ONNX sessions (CPU only)
        providers = ["CPUExecutionProvider"]

        def load_session(filename):
            return ort.InferenceSession(os.path.join(model_dir, filename), providers=providers)

        self.vision_session = load_session("vision_encoder_quantized.onnx")
        self.embed_session = load_session("embed_tokens_quantized.onnx")
        self.encoder_session = load_session("encoder_model_quantized.onnx")
        self.decoder_session = load_session("decoder_model_quantized.onnx")
        print("✓ Florence-2 ONNX Engine initialized successfully")

    def generate_caption(self, image_path=None, image_array=None, task_prompt="",
                         max_new_tokens=1024, min_new_tokens=250, repetition_penalty=1.5):
        """Generate a caption for an image.

        Args:
            image_path: Path of an image file; takes precedence over *image_array*.
            image_array: A PIL ``Image`` or a numpy array accepted by
                ``PIL.Image.fromarray``.
            task_prompt: A key of ``PROMPT_MAP`` or a literal prompt string.
            max_new_tokens: Upper bound on the number of generated tokens.
            min_new_tokens: EOS is suppressed for this many decoding steps, so
                the output is at least this long unless *max_new_tokens* is
                smaller.
            repetition_penalty: Positive logits of tokens seen in the last
                ``REPETITION_WINDOW`` steps are divided by this value and
                negative ones multiplied; ``1.0`` disables the penalty.

        Returns:
            The decoded caption with special tokens removed.

        Raises:
            ValueError: If neither a usable path nor an image/array is given.
        """
        image = self._to_rgb_image(image_path, image_array)
        print(f"--- Running Inference (Max Tokens: {max_new_tokens}) ---")

        pixel_values = self.image_processor(images=image, return_tensors="np")["pixel_values"]
        text_prompt = self.PROMPT_MAP.get(task_prompt, task_prompt)
        input_ids = self.tokenizer(text_prompt, return_tensors="np")["input_ids"]

        start_time = time.time()
        hidden_states, attention_mask = self._encode(pixel_values, input_ids)
        generated_ids = self._greedy_decode(
            hidden_states, attention_mask, max_new_tokens, min_new_tokens, repetition_penalty
        )
        end_time = time.time()

        caption = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"Inference complete in {end_time - start_time:.2f}s")
        return caption

    @staticmethod
    def _to_rgb_image(image_path=None, image_array=None):
        """Return an RGB PIL image from a file path, a PIL image, or a numpy array."""
        if image_path:
            # The context manager closes the file once the pixels are converted.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        if isinstance(image_array, Image.Image):
            return image_array.convert("RGB")
        if isinstance(image_array, np.ndarray):
            return Image.fromarray(image_array).convert("RGB")
        raise ValueError("Either image_path or image_array must be provided")

    def _encode(self, pixel_values, input_ids):
        """Run the vision, embedding and encoder sessions.

        Returns ``(encoder_hidden_states, attention_mask)``; the mask is an
        all-ones int64 array of shape ``(1, sequence_length)`` (batch of one).
        """
        # 1. Vision features
        image_features = self.vision_session.run(None, {"pixel_values": pixel_values})[0]
        # 2. Text embeddings
        text_embeds = self.embed_session.run(None, {"input_ids": input_ids})[0]
        # 3. Encoder fusion: image-feature tokens followed by prompt embeddings
        # along the sequence axis.
        combined_embeds = np.concatenate([image_features, text_embeds], axis=1)
        attention_mask = np.ones((1, combined_embeds.shape[1]), dtype=np.int64)
        encoder_outputs = self.encoder_session.run(None, {
            "inputs_embeds": combined_embeds,
            "attention_mask": attention_mask,
        })
        return encoder_outputs[0], attention_mask

    def _greedy_decode(self, encoder_hidden_states, encoder_attention_mask,
                       max_new_tokens, min_new_tokens, repetition_penalty):
        """Greedily decode token ids; the returned list starts with the start token."""
        eos = self.EOS_TOKEN_ID
        generated_ids = [eos]  # decoder start token
        for step in range(max_new_tokens):
            decoder_input_ids = np.array([generated_ids], dtype=np.int64)
            decoder_embeds = self.embed_session.run(None, {"input_ids": decoder_input_ids})[0]
            logits = self.decoder_session.run(None, {
                "inputs_embeds": decoder_embeds,
                "encoder_hidden_states": encoder_hidden_states,
                "encoder_attention_mask": encoder_attention_mask,
            })[0]
            # Logits at the last position; copied so the penalty below does not
            # write into the session's output buffer.
            next_logits = logits[0, -1, :].copy()

            # Repetition penalty, applied once to each distinct recent token.
            recent = np.fromiter(set(generated_ids[-self.REPETITION_WINDOW:]), dtype=np.int64)
            scores = next_logits[recent]
            next_logits[recent] = np.where(
                scores > 0, scores / repetition_penalty, scores * repetition_penalty
            )

            if step < min_new_tokens:
                # EOS is not allowed yet: pick the best non-EOS token instead.
                next_logits[eos] = -np.inf
            next_token = int(np.argmax(next_logits))
            if next_token == eos:
                break
            generated_ids.append(next_token)

            if (step + 1) % 50 == 0:
                print(f"Generated {step + 1} tokens...")
        return generated_ids


# Global engine instance; stays None until initialize_engine() has run.
engine = None


def initialize_engine():
    """Initialize the Florence2 ONNX engine (install deps, download model, load)."""
    global engine
    setup_environment()
    download_model()
    engine = Florence2ONNXEngine()


# FastAPI app setup
app = FastAPI(
    title="Florence-2 ONNX Image Captioning Server",
    description="Auto-captions images using Florence-2 ONNX models",
)

# Model identifier reported in every response.
MODEL_NAME = "Florence-2-base-ft-quantized-onnx"

# Captioning runs one request at a time: every request shares the same
# tokenizer and ONNX sessions, and parallel CPU inference would mostly compete
# for the same cores.
_inference_lock = threading.Lock()


def load_image_from_url(image_url: str) -> Image.Image:
    """Download *image_url* and decode it as an RGB image.

    Raises:
        ValueError: if the request fails, returns an HTTP error status, or the
            payload cannot be decoded as an image.
    """
    # NOTE(review): image_url comes straight from the query string, so the
    # server will fetch arbitrary hosts (SSRF risk) and buffers the whole body
    # in memory; add a host allow-list and a size cap if exposed publicly.
    try:
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}") from e


def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes as an RGB image.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(BytesIO(image_bytes))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from bytes: {e}") from e


async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in FastAPI's worker thread pool and await it.

    Downloads and model inference block for seconds to minutes; calling them
    directly inside an ``async def`` endpoint would stall the event loop, so no
    other request (including /health) could be served in the meantime.
    """
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(func, *args, **kwargs)


def _caption_image(model, image):
    """Caption *image* with *model*, serialized by the module inference lock."""
    with _inference_lock:
        return model.generate_caption(image_array=image)


# API Endpoints
@app.get("/")
async def root():
    """Root endpoint - shows server status"""
    return {
        "name": "Florence-2 ONNX Image Captioning Server",
        "status": "running",
        "model": MODEL_NAME,
        "model_loaded": engine is not None,
        "endpoints": {
            "GET /health": "Health check",
            "GET /analyze": "Analyze image from URL",
            "POST /analyze": "Analyze uploaded image",
        },
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "healthy" if engine is not None else "initializing",
        "model": MODEL_NAME,
        "model_loaded": engine is not None,
    }


@app.get("/analyze")
async def analyze_get(image_url: str = None):
    """Analyze an image by URL.

    Usage: /analyze?image_url=https://example.com/image.jpg
    """
    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")
        if not image_url:
            raise HTTPException(status_code=400, detail="image_url query parameter is required")

        # An unreachable or undecodable URL is a client error (400), matching
        # how POST /analyze treats an unreadable upload.
        try:
            image = await _run_blocking(load_image_from_url, image_url)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        caption = await _run_blocking(_caption_image, engine, image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
            "model": MODEL_NAME,
        })
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)},
        )


@app.post("/analyze")
async def analyze_post(file: UploadFile = File(None)):
    """Analyze an uploaded image (multipart/form-data).

    Returns:
        JSON with caption and metadata
    """
    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")
        if file is None:
            raise HTTPException(status_code=400, detail="file is required")

        content = await file.read()

        try:
            image = await _run_blocking(load_image_from_bytes, content)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}") from e

        caption = await _run_blocking(_caption_image, engine, image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "filename": file.filename,
            "image_size": {"width": image.width, "height": image.height},
            "model": MODEL_NAME,
        })
    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)},
        )


# Listening port: $PORT if set, otherwise 7860.
port = int(os.environ.get("PORT", 7860))

# Launch server
if __name__ == "__main__":
    print("Initializing Florence-2 ONNX Engine...")
    initialize_engine()
    print(f"\n✓ Server ready! Starting on 0.0.0.0:{port}")
    print(f"API Documentation: http://localhost:{port}/docs")
    uvicorn.run(app, host="0.0.0.0", port=port)