File size: 6,135 Bytes
a1c932a
 
 
 
 
 
 
 
d03c47c
 
a1c932a
d03c47c
a1c932a
 
 
 
 
 
 
d03c47c
a1c932a
d03c47c
a1c932a
d03c47c
 
 
a1c932a
 
 
d03c47c
 
a1c932a
 
3dc870c
 
d03c47c
 
 
29f706c
d03c47c
 
 
 
a1c932a
beccd45
 
 
d03c47c
29f706c
 
 
d03c47c
29f706c
beccd45
 
29f706c
 
 
 
d03c47c
29f706c
 
 
 
 
a1c932a
d03c47c
a1c932a
d03c47c
 
 
a1c932a
 
 
 
 
 
 
d03c47c
 
 
 
 
a1c932a
beccd45
a1c932a
d03c47c
 
a1c932a
d03c47c
a1c932a
 
 
 
 
 
 
d03c47c
a1c932a
 
 
 
 
 
 
 
 
d03c47c
 
a1c932a
 
 
 
 
 
 
 
 
 
 
d03c47c
 
a1c932a
d03c47c
a1c932a
 
d03c47c
a1c932a
 
 
 
 
 
d03c47c
a1c932a
 
 
d03c47c
 
 
a1c932a
 
 
 
 
 
 
 
 
beccd45
d03c47c
beccd45
 
 
 
 
 
 
 
 
 
 
a1c932a
beccd45
d03c47c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import base64
import io
import os

from PIL import Image
import torch
import numpy as np

# Import your custom utility functions
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)

# Import YOLO from ultralytics and transformers for captioning
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM

# ---------------------------------------------------------------------------
# Load the YOLO icon-detection model
# ---------------------------------------------------------------------------
# Fix: pick the device up front instead of attempting CUDA unconditionally.
# The original always raised (and fell through to the CPU path) on a
# CPU-only host; probing availability first keeps the same result without
# the guaranteed exception.
_yolo_device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # weights_only=False is required because the checkpoint stores a full
    # pickled model object under the "model" key, not just a state dict.
    # NOTE(review): torch.load with weights_only=False executes arbitrary
    # code from the checkpoint — only load trusted weight files.
    yolo_model = torch.load(
        "weights/icon_detect/best.pt",
        map_location=_yolo_device,
        weights_only=False,
    )["model"]
    yolo_model = yolo_model.to(_yolo_device)
except Exception as e:
    print("Error loading YOLO model on CUDA:", e)
    # Fallback: retry on CPU (e.g. CUDA OOM or driver problems).
    yolo_model = torch.load(
        "weights/icon_detect/best.pt",
        map_location="cpu",
        weights_only=False,
    )["model"]

print(f"YOLO model type: {type(yolo_model)}")

# ---------------------------------------------------------------------------
# Load the captioning model (Florence-2)
# ---------------------------------------------------------------------------
# Select device and matching dtype: half precision on GPU, full on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# The processor handles tokenization and image preprocessing for Florence-2.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

_florence_weights = "weights/icon_caption_florence"
try:
    model = AutoModelForCausalLM.from_pretrained(
        _florence_weights,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    print(f"Error loading caption model: {str(e)}")
    # Fallback to CPU with float32
    model = AutoModelForCausalLM.from_pretrained(
        _florence_weights,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to("cpu")

# Force configuration for the DaViT vision tower when the checkpoint omits it.
if not hasattr(model.config, 'vision_config'):
    model.config.vision_config = {}
if 'model_type' not in model.config.vision_config:
    model.config.vision_config['model_type'] = 'davit'

# Bundle processor + model; downstream helpers expect this dict shape.
caption_model_processor = {"processor": processor, "model": model}
print("Finish loading caption model!")

# ---------------------------------------------------------------------------
# Create FastAPI application and response model
# ---------------------------------------------------------------------------
app = FastAPI()

class ProcessResponse(BaseModel):
    """Response payload returned by the /process_image endpoint."""
    image: str  # Base64 encoded image (annotated PNG)
    parsed_content_list: str  # newline-joined parsed UI element descriptions
    label_coordinates: str  # str() of the label -> coordinates mapping

# ---------------------------------------------------------------------------
# Main processing function
# ---------------------------------------------------------------------------
def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR + icon detection + captioning on a screenshot.

    Args:
        image_input: Input screenshot as a PIL image.
        box_threshold: Confidence threshold for YOLO box detection.
        iou_threshold: IoU threshold used to merge overlapping boxes.

    Returns:
        ProcessResponse with the base64-encoded annotated PNG, the parsed
        content list (one element per line), and the label coordinates.
    """
    # The downstream helpers (check_ocr_box / get_som_labeled_img) take a
    # file path, so the image must be persisted to disk first.
    image_save_path = "imgs/saved_image_demo.png"
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # Scale the annotation drawing parameters with the image width so labels
    # stay readable at any resolution (3200 px is the reference width).
    # Fix: use the in-memory image's size directly instead of re-opening the
    # file we just wrote — same value, one less disk round-trip.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    # Run OCR to get recognized text and OCR bounding boxes. Goal filtering
    # is disabled (goal_filtering=None), so the second return value is unused.
    ocr_bbox_rslt, _ = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    # Run YOLO detection + semantic captioning; returns a base64-encoded
    # annotated image plus per-label coordinates and parsed descriptions.
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,  # (sic) keyword name is defined by utils
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )
    print("Finish processing")

    # Decode the helper's base64 output and re-encode as PNG so the response
    # format is stable regardless of what the helper produced.
    annotated = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    buffered = io.BytesIO()
    annotated.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return ProcessResponse(
        image=img_str,
        parsed_content_list="\n".join(parsed_content_list),
        label_coordinates=str(label_coordinates),
    )

# ---------------------------------------------------------------------------
# FastAPI endpoint for image processing
# ---------------------------------------------------------------------------
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Accept an uploaded screenshot and return the parsed/annotated result.

    Any failure is surfaced as an HTTP 500 with the exception text; the
    full traceback is printed server-side for debugging.
    """
    try:
        raw_bytes = await image_file.read()
        image_input = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        # Debug logging for file information
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")

        result = process(image_input, box_threshold, iou_threshold)

        # Guard against an empty payload slipping through as a 200.
        if not result.image:
            raise ValueError("Empty image in response")

        return result

    except Exception as e:
        import traceback
        traceback.print_exc()  # full server-side traceback for debugging
        raise HTTPException(status_code=500, detail=str(e))