potato18 / app.py
makeitfr's picture
Upload app.py with huggingface_hub
0f04363 verified
Raw
History Blame Contribute Delete
11 kB
import os
import sys
import time
import subprocess
import numpy as np
from PIL import Image
from io import BytesIO
import requests
import threading
# FastAPI imports
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse
import uvicorn
# 1. Environment Setup & Dependency Installation
def setup_environment():
    """Ensure the runtime packages can be imported, installing them if not.

    Probes ``huggingface_hub``, ``onnxruntime`` and ``transformers``. If any
    probe raises ``ImportError`` the complete dependency list (which also
    names ``pillow`` and ``numpy``) is installed with ``pip`` under the
    current interpreter.

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command exits
            with a non-zero status.
    """
    print("--- Setting up environment ---")
    dependencies = ["huggingface_hub", "onnxruntime", "transformers", "pillow", "numpy"]
    # Only these three are probed; the other entries are installed alongside
    # them when a probe fails.
    probed_modules = ("huggingface_hub", "onnxruntime", "transformers")
    try:
        for module_name in probed_modules:
            __import__(module_name)
    except ImportError:
        print("Installing dependencies...")
        pip_command = [sys.executable, "-m", "pip", "install", *dependencies]
        subprocess.check_call(pip_command)
    else:
        print("Dependencies already satisfied.")
# 2. Model Download
def download_model(repo_id="Heliosoph/florence-2-base-ft-quantized-onnx", local_dir="florence2_quantized"):
    """Download the model snapshot unless ``local_dir`` already exists.

    Args:
        repo_id: Hugging Face Hub repository to pull the snapshot from.
        local_dir: Directory the snapshot is written to.

    Only the existence of ``local_dir`` is checked; its contents are not
    inspected, so an existing directory always skips the download.
    """
    # Imported here (not at module level) so the import is attempted on
    # every call, before the directory check.
    from huggingface_hub import snapshot_download

    if os.path.exists(local_dir):
        print(f"Model directory '{local_dir}' already exists.")
        return

    print(f"--- Downloading model from {repo_id} ---")
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print("Download complete.")
# 3. Inference Engine
class Florence2ONNXEngine:
    """Image captioning engine built on four quantized ONNX graphs.

    The constructor loads an image processor and a tokenizer by name via
    ``from_pretrained`` and opens one ONNX Runtime session each for the
    vision encoder, the token embedding, the encoder and the decoder found
    in ``model_dir``. ``generate_caption`` chains them and decodes greedily.
    """

    # Token id used both to seed the decoder and to detect end-of-sequence.
    _START_EOS_TOKEN_ID = 2

    # Number of most recent generated tokens subject to the repetition penalty.
    _REPETITION_WINDOW = 50

    # Task tags that are expanded into a natural-language prompt; any other
    # value of ``task_prompt`` is passed to the tokenizer unchanged.
    _PROMPT_MAP = {
        "<CAPTION>": "What does the image describe?",
        "<DETAILED_CAPTION>": "Describe this image in detail.",
        "<MORE_DETAILED_CAPTION>": "Describe this image in great detail with every object and background.",
    }

    def __init__(self, model_dir="florence2_quantized"):
        """Load the processors and open the ONNX sessions.

        Args:
            model_dir: Directory containing the four ``*_quantized.onnx``
                files opened below.
        """
        import onnxruntime as ort
        from transformers import CLIPImageProcessor, AutoTokenizer

        self.model_dir = model_dir
        print("--- Initializing ONNX Engine ---")

        # Load processors
        self.image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-base-ft")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

        # Load ONNX sessions (CPU only)
        providers = ['CPUExecutionProvider']
        self.vision_session = ort.InferenceSession(os.path.join(model_dir, 'vision_encoder_quantized.onnx'), providers=providers)
        self.embed_session = ort.InferenceSession(os.path.join(model_dir, 'embed_tokens_quantized.onnx'), providers=providers)
        self.encoder_session = ort.InferenceSession(os.path.join(model_dir, 'encoder_model_quantized.onnx'), providers=providers)
        self.decoder_session = ort.InferenceSession(os.path.join(model_dir, 'decoder_model_quantized.onnx'), providers=providers)
        print("✓ Florence-2 ONNX Engine initialized successfully")

    @staticmethod
    def _load_image(image_path, image_array):
        """Return the input as an RGB PIL image, preferring ``image_path``."""
        if image_path:
            # Close the underlying file as soon as the pixels are converted;
            # convert() returns an independent image.
            with Image.open(image_path) as opened:
                return opened.convert("RGB")
        if image_array is not None and isinstance(image_array, Image.Image):
            return image_array.convert("RGB")
        raise ValueError("Either image_path or image_array must be provided")

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Return a copy of ``logits`` with each listed token made less likely.

        Positive logits are divided by ``penalty`` and non-positive ones are
        multiplied by it, so the adjustment always lowers the score when
        ``penalty`` is greater than 1. Each distinct token is penalised once.
        """
        adjusted = logits.copy()
        for token in set(token_ids):
            if adjusted[token] > 0:
                adjusted[token] /= penalty
            else:
                adjusted[token] *= penalty
        return adjusted

    def generate_caption(self, image_path=None, image_array=None, task_prompt="<MORE_DETAILED_CAPTION>", max_new_tokens=1024, min_new_tokens=250, repetition_penalty=1.5):
        """Generate a caption from an image path or a PIL Image object.

        Args:
            image_path: Path (or other value accepted by ``Image.open``) of
                the image. Takes precedence over ``image_array`` when truthy.
            image_array: A ``PIL.Image.Image`` instance. Despite the name,
                other types (for example NumPy arrays) are rejected.
            task_prompt: One of the known task tags, or free-form prompt text.
            max_new_tokens: Upper bound on decoding steps.
            min_new_tokens: End-of-sequence is suppressed for decoding steps
                with index below this value.
            repetition_penalty: Factor applied to the logits of recently
                generated tokens.

        Returns:
            The decoded caption string with special tokens skipped.

        Raises:
            ValueError: if neither a usable ``image_path`` nor a PIL image
                is supplied.
        """
        image = self._load_image(image_path, image_array)

        print(f"--- Running Inference (Max Tokens: {max_new_tokens}) ---")
        pixel_values = self.image_processor(images=image, return_tensors="np")['pixel_values']

        # Map specific prompts to descriptive strings if needed
        text_prompt = self._PROMPT_MAP.get(task_prompt, task_prompt)
        input_ids = self.tokenizer(text_prompt, return_tensors="np")['input_ids']

        # The reported duration covers the ONNX stages only, not preprocessing.
        start_time = time.time()

        # 1. Vision features
        image_features = self.vision_session.run(None, {'pixel_values': pixel_values})[0]

        # 2. Text embeddings
        text_embeds = self.embed_session.run(None, {'input_ids': input_ids})[0]

        # 3. Encoder fusion: image features followed by prompt embeddings
        #    along axis 1, attended to in full (mask of ones).
        combined_embeds = np.concatenate([image_features, text_embeds], axis=1)
        attention_mask = np.ones((1, combined_embeds.shape[1]), dtype=np.int64)
        encoder_outputs = self.encoder_session.run(None, {
            'inputs_embeds': combined_embeds,
            'attention_mask': attention_mask
        })
        last_hidden_state = encoder_outputs[0]

        # 4. Greedy autoregressive decoding with a repetition penalty.
        #    No decoder state is carried between steps: the whole generated
        #    prefix is re-embedded and re-decoded on every iteration.
        eos_id = self._START_EOS_TOKEN_ID
        generated_ids = [eos_id]
        for step in range(max_new_tokens):
            decoder_input_ids = np.array([generated_ids], dtype=np.int64)
            decoder_embeds = self.embed_session.run(None, {'input_ids': decoder_input_ids})[0]
            logits = self.decoder_session.run(None, {
                'inputs_embeds': decoder_embeds,
                'encoder_hidden_states': last_hidden_state,
                'encoder_attention_mask': attention_mask
            })[0]

            # Scores for the last position of the first batch item.
            current_logits = self._apply_repetition_penalty(
                logits[0, -1, :],
                generated_ids[-self._REPETITION_WINDOW:],
                repetition_penalty,
            )
            # Store plain ints so the id list stays free of NumPy scalars.
            next_token = int(np.argmax(current_logits))

            # Only allow the EOS token after the minimum generation length:
            # before that, mask it out and take the runner-up instead.
            if next_token == eos_id and step < min_new_tokens:
                current_logits[eos_id] = -1e9
                next_token = int(np.argmax(current_logits))

            if next_token == eos_id:
                break
            generated_ids.append(next_token)

            if (step + 1) % 50 == 0:
                print(f"Generated {step+1} tokens...")

        end_time = time.time()
        caption = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"Inference complete in {end_time - start_time:.2f}s")
        return caption
# Global engine instance
# Stays None until initialize_engine() has finished; the endpoints read it
# to decide whether the model is ready.
engine = None


def initialize_engine():
    """Prepare the environment and model files, then publish the engine.

    Runs the dependency setup and the model download in that order, builds
    a ``Florence2ONNXEngine`` and stores it in the module-level ``engine``.
    """
    global engine
    setup_environment()
    download_model()
    loaded_engine = Florence2ONNXEngine()
    engine = loaded_engine
# FastAPI app setup
# Application object that the route decorators in this file register against.
app = FastAPI(
    title="Florence-2 ONNX Image Captioning Server",
    description="Auto-captions images using Florence-2 ONNX models",
)
def load_image_from_url(image_url: str) -> Image.Image:
    """Download an image and return it converted to RGB.

    Args:
        image_url: Address to fetch with a 30 second timeout.

    Returns:
        The decoded image in ``RGB`` mode.

    Raises:
        ValueError: if the request, the HTTP status check or the image
            decoding fails. The underlying exception is chained as
            ``__cause__`` so the real failure stays visible in tracebacks.
    """
    # NOTE(review): image_url is fetched server-side exactly as supplied by
    # the caller. If this service can reach internal hosts, restrict the
    # allowed destinations before exposing it (server-side request forgery).
    try:
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from URL: {e}") from e
def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Decode an in-memory image and return it converted to RGB.

    Args:
        image_bytes: Raw encoded image data.

    Returns:
        The decoded image in ``RGB`` mode.

    Raises:
        ValueError: if the data cannot be opened or converted. The
            underlying exception is chained as ``__cause__`` so the real
            failure stays visible in tracebacks.
    """
    try:
        image = Image.open(BytesIO(image_bytes))
        return image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image from bytes: {e}") from e
# API Endpoints
@app.get("/")
async def root():
    """Describe the service: name, status, model and available routes."""
    available_endpoints = {
        "GET /health": "Health check",
        "GET /analyze": "Analyze image from URL",
        "POST /analyze": "Analyze uploaded image",
    }
    # The engine global is only set once initialization has completed.
    model_loaded = engine is not None
    return {
        "name": "Florence-2 ONNX Image Captioning Server",
        "status": "running",
        "model": "Florence-2-base-ft-quantized-onnx",
        "model_loaded": model_loaded,
        "endpoints": available_endpoints,
    }
@app.get("/health")
async def health():
    """Report whether the captioning engine has been created yet."""
    model_loaded = engine is not None
    if model_loaded:
        status = "healthy"
    else:
        status = "initializing"
    return {
        "status": status,
        "model": "Florence-2-base-ft-quantized-onnx",
        "model_loaded": model_loaded,
    }
@app.get("/analyze")
async def analyze_get(image_url: str = None):
    """Analyze an image by URL.

    Usage: /analyze?image_url=https://example.com/image.jpg

    Responses:
        200: JSON with ``success``, ``caption``, ``image_size`` and ``model``.
        400: ``image_url`` is missing, or the image could not be fetched
            or decoded.
        503: the engine has not been initialized yet.
        500: any other failure, reported as ``{"success": False, "error": ...}``.
    """
    # Imported locally so this handler does not depend on other edits to the
    # module's import section.
    from fastapi.concurrency import run_in_threadpool

    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")
        if not image_url:
            raise HTTPException(status_code=400, detail="image_url query parameter is required")

        # The download is blocking I/O; run it off the event loop so other
        # requests (such as health checks) keep being served.
        try:
            image = await run_in_threadpool(load_image_from_url, image_url)
        except ValueError as e:
            # An unusable URL or image is a client error, matching how the
            # POST endpoint reports an unreadable upload.
            raise HTTPException(status_code=400, detail=f"Failed to load image from URL: {e}") from e

        # Inference is CPU-bound and long-running; keep it off the event
        # loop for the same reason.
        caption = await run_in_threadpool(engine.generate_caption, image_array=image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "image_size": {"width": image.width, "height": image.height},
            "model": "Florence-2-base-ft-quantized-onnx"
        })
    except HTTPException:
        # Let FastAPI render deliberate HTTP errors itself.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
@app.post("/analyze")
async def analyze_post(file: UploadFile = File(None)):
    """Analyze an uploaded image (multipart/form-data).

    Returns: JSON with caption and metadata

    Responses:
        200: JSON with ``success``, ``caption``, ``filename``,
            ``image_size`` and ``model``.
        400: no file was sent, or the upload could not be decoded.
        503: the engine has not been initialized yet.
        500: any other failure, reported as ``{"success": False, "error": ...}``.
    """
    # Imported locally so this handler does not depend on other edits to the
    # module's import section.
    from fastapi.concurrency import run_in_threadpool

    try:
        if engine is None:
            raise HTTPException(status_code=503, detail="Model not initialized")
        if file is None:
            raise HTTPException(status_code=400, detail="file is required")

        # Read uploaded file
        content = await file.read()

        # Load image from bytes
        try:
            image = load_image_from_bytes(content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to read uploaded image: {e}") from e

        # Inference is CPU-bound and long-running; run it in a worker thread
        # so the event loop stays free for other requests (such as health
        # checks).
        caption = await run_in_threadpool(engine.generate_caption, image_array=image)

        return JSONResponse(content={
            "success": True,
            "caption": caption,
            "filename": file.filename,
            "image_size": {"width": image.width, "height": image.height},
            "model": "Florence-2-base-ft-quantized-onnx"
        })
    except HTTPException:
        # Let FastAPI render deliberate HTTP errors itself.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
# Get the port from environment variable
# Listening port: taken from the PORT environment variable, else 7860.
port = int(os.getenv("PORT", 7860))

# Script entry point: load the model first, then start serving.
if __name__ == "__main__":
    print("Initializing Florence-2 ONNX Engine...")
    initialize_engine()
    startup_banner = (
        f"\n✓ Server ready! Starting on 0.0.0.0:{port}",
        f"API Documentation: http://localhost:{port}/docs",
    )
    for banner_line in startup_banner:
        print(banner_line)
    uvicorn.run(app, host="0.0.0.0", port=port)