Spaces:
Sleeping
Sleeping
File size: 6,135 Bytes
a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a 3dc870c d03c47c 29f706c d03c47c a1c932a beccd45 d03c47c 29f706c d03c47c 29f706c beccd45 29f706c d03c47c 29f706c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a beccd45 a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a d03c47c a1c932a beccd45 d03c47c beccd45 a1c932a beccd45 d03c47c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import base64
import io
import os
import tempfile

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
# Import YOLO from ultralytics and transformers for captioning
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO

# Import your custom utility functions
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
# ---------------------------------------------------------------------------
# Load the YOLO model
# ---------------------------------------------------------------------------
_YOLO_WEIGHTS_PATH = "weights/icon_detect/best.pt"


def _load_yolo_checkpoint(map_location):
    """Load the icon-detection checkpoint and return its ``"model"`` entry.

    SECURITY: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    Python objects, which can execute code. That is only acceptable because
    the path is a fixed local weights file; never point this at
    user-supplied data.
    """
    checkpoint = torch.load(_YOLO_WEIGHTS_PATH, map_location=map_location, weights_only=False)
    return checkpoint["model"]


yolo_model = None
# Skip the CUDA attempt on CPU-only hosts, where it would always fail and
# print a spurious "Error loading YOLO model on CUDA" message.
if torch.cuda.is_available():
    try:
        yolo_model = _load_yolo_checkpoint("cuda").to("cuda")
    except Exception as e:
        print("Error loading YOLO model on CUDA:", e)
if yolo_model is None:
    # No GPU, or the GPU load failed: load onto the CPU instead.
    yolo_model = _load_yolo_checkpoint("cpu")
print(f"YOLO model type: {type(yolo_model)}")
# ---------------------------------------------------------------------------
# Load the captioning model (Florence-2)
# ---------------------------------------------------------------------------
_CAPTION_MODEL_PATH = "weights/icon_caption_florence"

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only used on the GPU; the CPU path always uses float32.
dtype = torch.float16 if device == "cuda" else torch.float32
# The processor comes from the public base checkpoint, while the model
# weights are loaded from a local directory.
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
try:
    model = AutoModelForCausalLM.from_pretrained(
        _CAPTION_MODEL_PATH,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    print(f"Error loading caption model: {str(e)}")
    # Fallback to CPU with float32. Update the module-level device/dtype too,
    # so they describe where the model actually lives.
    device, dtype = "cpu", torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        _CAPTION_MODEL_PATH,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)

# Force configuration for the DaViT vision tower if missing.
# vision_config may be a plain dict or a config object, so handle both.
# NOTE(review): this runs after from_pretrained, so it cannot influence how
# the model was built; presumably downstream code reads it. TODO confirm it
# is still needed.
_vision_config = getattr(model.config, "vision_config", None)
if _vision_config is None:
    model.config.vision_config = {"model_type": "davit"}
elif isinstance(_vision_config, dict):
    _vision_config.setdefault("model_type", "davit")
elif not getattr(_vision_config, "model_type", None):
    _vision_config.model_type = "davit"

caption_model_processor = {"processor": processor, "model": model}
print("Finish loading caption model!")
# ---------------------------------------------------------------------------
# Application object (served by an ASGI server) and its response schema
# ---------------------------------------------------------------------------
app: FastAPI = FastAPI()  # routes are registered on this instance below
class ProcessResponse(BaseModel):
    """Response body returned by the ``/process_image`` endpoint."""

    # Base64 encoded image (the annotated image, re-encoded as PNG by process()).
    image: str
    # Parsed element descriptions joined with newlines.
    parsed_content_list: str
    # str() of the coordinates returned by get_som_labeled_img. This is a
    # Python repr, not JSON. The coordinates are requested with
    # output_coord_in_ratio=True, presumably as fractions of the image size
    # (TODO confirm against utils).
    label_coordinates: str
# ---------------------------------------------------------------------------
# Main processing function
# ---------------------------------------------------------------------------
def _draw_bbox_config_for_width(image_width):
    """Return box/label drawing parameters scaled to ``image_width``.

    Sizes scale linearly relative to a 3200 px wide reference image. Integer
    sizes are clamped to at least 1 so boxes and labels never vanish on small
    images.
    """
    box_overlay_ratio = image_width / 3200  # adjust scaling factor as needed
    return {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }


def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR, YOLO detection and captioning on ``image_input``.

    Args:
        image_input: The image to parse.
        box_threshold: Detection confidence threshold, passed to
            get_som_labeled_img as ``BOX_TRESHOLD``.
        iou_threshold: IoU threshold passed to get_som_labeled_img.

    Returns:
        A ProcessResponse containing:
        - the annotated image as a base64-encoded PNG;
        - the parsed contents joined with newlines;
        - ``str()`` of the label coordinates.
    """
    # The OCR/detection helpers take a file path, so the image must be written
    # to disk. Each call gets its own temporary file, rather than one fixed
    # shared path, so concurrent requests or worker processes cannot
    # overwrite each other's input.
    fd, image_save_path = tempfile.mkstemp(prefix="process_", suffix=".png")
    os.close(fd)
    try:
        image_input.save(image_save_path, format="PNG")
        draw_bbox_config = _draw_bbox_config_for_width(image_input.size[0])

        # Run OCR to get text and OCR bounding boxes (goal-filter flag unused).
        (text, ocr_bbox), _is_goal_filtered = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )

        # Run YOLO and semantic processing to get the labeled image and boxes.
        labeled_img_b64, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    finally:
        try:
            os.remove(image_save_path)
        except OSError:
            pass  # best-effort cleanup; the result does not depend on it

    # get_som_labeled_img returns a base64-encoded image; decode it and
    # re-encode as PNG so the response format is fixed here.
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(base64.b64decode(labeled_img_b64))) as labeled_image:
        labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    print("Finish processing")

    # str() keeps the join working even if the helper returns non-string items.
    parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
# ---------------------------------------------------------------------------
# FastAPI endpoint for image processing
# ---------------------------------------------------------------------------
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded image and return the annotated result.

    Args:
        image_file: Uploaded image file; it is converted to RGB before use.
        box_threshold: Detection confidence threshold forwarded to process().
        iou_threshold: IoU threshold forwarded to process().

    Returns:
        The ProcessResponse produced by process().

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
            500 for any other failure; the traceback is printed.
    """
    try:
        contents = await image_file.read()
        try:
            with Image.open(io.BytesIO(contents)) as uploaded:
                image_input = uploaded.convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Unidentifiable, truncated or oversized uploads are client
            # errors, not server failures.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        # Debug logging for file information
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")

        # NOTE(review): process() is synchronous model inference called
        # directly inside this async handler. It blocks the event loop for
        # the whole call; consider offloading it to a worker thread.
        response = process(image_input, box_threshold, iou_threshold)

        # Validate response
        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except HTTPException:
        # Already carries the intended status code; do not convert it to a 500.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()  # Print full traceback for debugging
        raise HTTPException(status_code=500, detail=str(e)) from e
|