File size: 6,817 Bytes
1e91d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
Model Manager - Handles loading and caching of YOLO and VLM models
"""

import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig
)
from ultralytics import YOLO
import os
from typing import Tuple

from config import (
    YOLO_MODEL_PATH,
    VLM_MODEL_ID,
    QUANTIZATION_CONFIG,
    YOLO_CONFIDENCE_THRESHOLD
)


class ModelManager:
    """Singleton that loads and serves the YOLO detector and the Qwen2.5-VL model.

    Every ``ModelManager()`` call returns the same shared instance, so the
    heavyweight models are loaded into process memory at most once. Models are
    NOT loaded at construction time — call :meth:`load_models` explicitly.
    """

    _instance = None      # the one shared instance (singleton storage)
    _initialized = False  # guards __init__ from re-running on repeat construction

    def __new__(cls):
        # Classic singleton: allocate on first call, return the cached
        # instance on every subsequent call.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call, even when __new__
        # returned the cached instance — only initialize attributes once so a
        # later construction cannot wipe already-loaded models.
        if not ModelManager._initialized:
            self.yolo_model = None   # ultralytics YOLO detector (set by load_models)
            self.vlm_model = None    # Qwen2.5-VL causal-LM (set by load_models)
            self.processor = None    # matching AutoProcessor (set by load_models)
            ModelManager._initialized = True

    def load_models(self):
        """Load the YOLO and VLM models into memory, then run a warmup pass.

        Raises:
            FileNotFoundError: if the YOLO weights file is missing.
        """
        print("πŸš€ Starting model loading...")

        # Load YOLO model
        self.yolo_model = self._load_yolo_model()

        # Load VLM model
        self.vlm_model, self.processor = self._load_vlm_model()

        # Warm up models to initialize CUDA context
        self._warmup_models()

        print("βœ… All models loaded successfully!")

    def _load_yolo_model(self) -> YOLO:
        """Load the trained YOLO model for signature and stamp detection.

        Returns:
            YOLO: the loaded detector.

        Raises:
            FileNotFoundError: if ``YOLO_MODEL_PATH`` does not exist.
        """
        if not os.path.exists(YOLO_MODEL_PATH):
            raise FileNotFoundError(
                f"YOLO model not found at {YOLO_MODEL_PATH}. "
                "Please ensure best.pt is in utils/models/"
            )

        yolo_model = YOLO(str(YOLO_MODEL_PATH))
        print(f"βœ… YOLO model loaded from {YOLO_MODEL_PATH}")
        return yolo_model

    def _load_vlm_model(self) -> Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
        """Load Qwen2.5-VL with 4-bit quantization.

        Downloads the weights from Hugging Face on first run.

        Returns:
            tuple: ``(model, processor)`` — the quantized model in eval mode
            and its matching processor.
        """
        print(f"πŸ“₯ Loading VLM model: {VLM_MODEL_ID}")
        print("   (This will download ~4GB on first run)")

        # Configure 4-bit quantization; compute dtype name in the config is
        # resolved to the actual torch dtype object (e.g. "bfloat16" -> torch.bfloat16).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=QUANTIZATION_CONFIG["load_in_4bit"],
            bnb_4bit_quant_type=QUANTIZATION_CONFIG["bnb_4bit_quant_type"],
            bnb_4bit_compute_dtype=getattr(torch, QUANTIZATION_CONFIG["bnb_4bit_compute_dtype"]),
            bnb_4bit_use_double_quant=QUANTIZATION_CONFIG["bnb_4bit_use_double_quant"]
        )

        # Load processor
        processor = AutoProcessor.from_pretrained(
            VLM_MODEL_ID,
            trust_remote_code=True
        )

        # Load model with quantization; device_map="auto" lets accelerate
        # place the weights on the available device(s).
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            VLM_MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )

        # Inference only — disable dropout etc.
        model.eval()
        print("βœ… Qwen2.5-VL model loaded successfully")

        return model, processor

    def _warmup_models(self):
        """Run one tiny VLM inference to initialize the CUDA context.

        Failure here is non-fatal: the first real request will simply pay the
        initialization cost instead.
        """
        print("πŸ”₯ Warming up models (initializing CUDA context)...")
        import time
        from PIL import Image
        import numpy as np

        warmup_start = time.time()

        # Create a small dummy image (100x100 white RGB)
        dummy_image = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 255)

        try:
            # Warm up VLM with a minimal chat turn
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": dummy_image},
                        {"type": "text", "text": "warm up"}
                    ]
                }
            ]

            from qwen_vl_utils import process_vision_info
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            # Move inputs to wherever device_map="auto" placed the model,
            # rather than hardcoding "cuda" — keeps warmup working on CPU-only
            # hosts and matches the model's actual placement.
            inputs = inputs.to(self.vlm_model.device)

            # Run a quick inference
            with torch.no_grad():
                _ = self.vlm_model.generate(**inputs, max_new_tokens=5)

            # Clean up the warmup tensors before real requests arrive
            del inputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            warmup_time = time.time() - warmup_start
            print(f"βœ… Models warmed up in {warmup_time:.2f}s (CUDA context initialized)")

        except Exception as e:
            # Best-effort warmup: report and continue.
            print(f"⚠️ Warmup failed (non-critical): {e}")

    def detect_sign_stamp(self, image_path: str):
        """
        Detect signature and stamp in the image using YOLO.

        For each class, keeps only the highest-confidence box above
        ``YOLO_CONFIDENCE_THRESHOLD``.

        Args:
            image_path: path to the image to run detection on.

        Returns:
            tuple: (signature_info, stamp_info, signature_conf, stamp_conf)
            where each *_info is ``{"present": bool, "bbox": [x1, y1, x2, y2] | None}``.

        Raises:
            RuntimeError: if called before load_models().
        """
        if self.yolo_model is None:
            raise RuntimeError("YOLO model not loaded. Call load_models() first.")

        results = self.yolo_model(image_path, verbose=False)[0]

        signature_info = {"present": False, "bbox": None}
        stamp_info = {"present": False, "bbox": None}
        signature_conf = 0.0
        stamp_conf = 0.0

        if results.boxes is not None:
            for box in results.boxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])

                if conf > YOLO_CONFIDENCE_THRESHOLD:
                    # xyxy pixel coordinates, rounded down to ints
                    bbox = box.xyxy[0].cpu().numpy().tolist()
                    bbox = [int(coord) for coord in bbox]

                    # Class 0: signature, Class 1: stamp — keep the best box per class
                    if cls_id == 0 and conf > signature_conf:
                        signature_info = {"present": True, "bbox": bbox}
                        signature_conf = conf
                    elif cls_id == 1 and conf > stamp_conf:
                        stamp_info = {"present": True, "bbox": bbox}
                        stamp_conf = conf

        return signature_info, stamp_info, signature_conf, stamp_conf

    def is_loaded(self) -> bool:
        """Return True once all three models/processors have been loaded."""
        return (self.yolo_model is not None and
                self.vlm_model is not None and
                self.processor is not None)


# Global model manager instance — shared by all importers of this module.
# ModelManager is a singleton, so repeated construction elsewhere yields this
# same object; models are only loaded when load_models() is called explicitly.
model_manager = ModelManager()