# omniapi / main.py
# Scraped page header (not code): author banao-tech, commit "Update main.py"
# (d03c47c, verified), file size 6.14 kB.
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import base64
import io
import os
from PIL import Image
import torch
import numpy as np
# Import your custom utility functions
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
# Import YOLO from ultralytics and transformers for captioning
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM
# ---------------------------------------------------------------------------
# Load the YOLO model
# ---------------------------------------------------------------------------
def _load_yolo_model(checkpoint_path="weights/icon_detect/best.pt"):
    """Load the icon-detection model stored under ``"model"`` in a torch checkpoint.

    The GPU is attempted only when ``torch.cuda.is_available()``. On a CPU-only
    host the checkpoint is read once, directly onto the CPU, instead of first
    failing on CUDA, logging a spurious error and reading the file again.

    NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    objects -- only point this at trusted checkpoint files.
    NOTE(review): ``get_yolo_model`` and ``YOLO`` are imported by this file but
    not used; confirm ``get_som_labeled_img`` accepts this raw checkpoint module
    rather than an ultralytics ``YOLO`` wrapper.
    """
    if torch.cuda.is_available():
        try:
            model = torch.load(checkpoint_path, map_location="cuda", weights_only=False)["model"]
            return model.to("cuda")
        except Exception as e:
            # Any exception on the GPU path falls back to loading on the CPU.
            print("Error loading YOLO model on CUDA:", e)
    return torch.load(checkpoint_path, map_location="cpu", weights_only=False)["model"]


yolo_model = _load_yolo_model()
print(f"YOLO model type: {type(yolo_model)}")
# ---------------------------------------------------------------------------
# Load the captioning model (Florence-2)
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load the processor for the Florence-2 model
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)


def _load_caption_model(model_path="weights/icon_caption_florence"):
    """Load the local Florence-2 caption model.

    First tries the module-level ``device``/``dtype`` (float16 on CUDA); on any
    failure it retries on the CPU with float32.

    NOTE(review): after a fallback, ``device``/``dtype`` still describe the
    intended setup, not where the model actually ended up.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            trust_remote_code=True,
        ).to(device)
    except Exception as e:
        print(f"Error loading caption model: {str(e)}")
        # Fallback to CPU with float32
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        ).to("cpu")


def _ensure_vision_model_type(config, model_type="davit"):
    """Make sure ``config.vision_config`` declares a ``model_type``.

    ``vision_config`` may be missing/``None`` (a dict is created), a plain dict
    (key set only if absent), or an attribute-style config object (attribute
    set only if empty). The previous code only handled the dict form and
    raised on the others.

    NOTE(review): this runs after the model is constructed, so it cannot affect
    how the weights were loaded; presumably it is for later readers of the
    config -- confirm it is still needed.
    """
    vision_config = getattr(config, "vision_config", None)
    if vision_config is None:
        vision_config = {}
        config.vision_config = vision_config
    if isinstance(vision_config, dict):
        vision_config.setdefault("model_type", model_type)
    elif not getattr(vision_config, "model_type", None):
        vision_config.model_type = model_type


model = _load_caption_model()
# Force configuration for DaViT vision tower if missing
_ensure_vision_model_type(model.config)
caption_model_processor = {"processor": processor, "model": model}
print("Finish loading caption model!")
# ---------------------------------------------------------------------------
# Create FastAPI application and response model
# ---------------------------------------------------------------------------
app = FastAPI()


# Response body of the /process_image endpoint. All three fields are plain
# strings; documented with comments (not a docstring) so the generated OpenAPI
# schema is unchanged.
class ProcessResponse(BaseModel):
    image: str  # Base64 encoded image
    parsed_content_list: str  # parsed content entries joined with "\n" by process()
    label_coordinates: str  # str() of the coordinates from get_som_labeled_img (not JSON)
# ---------------------------------------------------------------------------
# Main processing function
# ---------------------------------------------------------------------------
def _build_draw_bbox_config(image_width):
    """Scale bounding-box annotation styling to the image width.

    Sizes are relative to a 3200 px wide image; integer sizes are truncated
    and clamped to at least 1 so small images still get visible boxes/labels.
    """
    ratio = image_width / 3200  # adjust scaling factor as needed
    return {
        "text_scale": 0.8 * ratio,
        "text_thickness": max(int(2 * ratio), 1),
        "text_padding": max(int(3 * ratio), 1),
        "thickness": max(int(3 * ratio), 1),
    }


def _reencode_as_png_base64(encoded_image):
    """Decode a base64-encoded image and return it re-encoded as base64 PNG text."""
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(base64.b64decode(encoded_image))) as img:
        img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR plus icon detection/captioning on an image.

    Args:
        image_input: image to analyse. It is written to a temporary PNG under
            ``imgs/`` because ``check_ocr_box``/``get_som_labeled_img`` take a
            file path; the file is deleted before returning.
        box_threshold: forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        ProcessResponse with the annotated image as base64 PNG, the parsed
        content entries joined by newlines, and ``str()`` of the label
        coordinates.
    """
    import tempfile

    image_dir = "imgs"
    os.makedirs(image_dir, exist_ok=True)
    # Unique file per call: the old fixed path let concurrent requests
    # (e.g. several server workers) overwrite each other's input image.
    fd, image_save_path = tempfile.mkstemp(prefix="upload_", suffix=".png", dir=image_dir)
    os.close(fd)
    try:
        image_input.save(image_save_path)
        # The saved file has the same dimensions, so there is no need to
        # re-open it (the old re-open leaked a file handle).
        draw_bbox_config = _build_draw_bbox_config(image_input.size[0])

        # Run OCR to get text and OCR bounding boxes (goal-filter flag unused).
        (text, ocr_bbox), _ = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )

        # Run YOLO and semantic processing to get the labeled image and bounding boxes
        labeled_img_b64, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,  # (sic) must match the callee's parameter name
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    finally:
        # Best-effort cleanup; a leftover temp file should not fail the request.
        try:
            os.remove(image_save_path)
        except OSError as e:
            print(f"Could not remove temporary image {image_save_path}: {e}")

    img_str = _reencode_as_png_base64(labeled_img_b64)
    print("Finish processing")
    # str() each entry so non-string entries don't make join() raise TypeError.
    parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
# ---------------------------------------------------------------------------
# FastAPI endpoint for image processing
# ---------------------------------------------------------------------------
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded image and return the annotated result.

    Args:
        image_file: uploaded image; converted to RGB before processing.
        box_threshold: forwarded to ``process``.
        iou_threshold: forwarded to ``process``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image (empty,
            truncated, non-image or decompression-bomb sized); 500 for any
            other failure, after printing the traceback to stderr.
    """
    import traceback

    try:
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Bad uploads are client errors, not server errors.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        # Debug logging for file information
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")

        # NOTE(review): process() is synchronous and runs on the event-loop
        # thread, so each worker handles one request at a time. Offloading it
        # (e.g. to a thread pool) requires the models/OCR helpers to be safe
        # for concurrent use -- confirm before changing.
        response = process(image_input, box_threshold, iou_threshold)

        # Validate response
        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except HTTPException:
        # Already carries the intended status code; don't rewrap it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()  # Print full traceback for debugging
        raise HTTPException(status_code=500, detail=str(e)) from e