github-actions[bot] committed on
Commit
1e91d4e
Β·
1 Parent(s): 1b8dcf1

Sync from GitHub: 0326ea25edafa877b6e50d9380e8b84ad62476c1

Browse files
.gitattributes CHANGED
@@ -1,35 +1,6 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
4
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
5
  *.onnx filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
6
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
.github/workflows/push_to_huggingface.yml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Sync this repository to a Hugging Face Space on every push to main/master.
name: Push to Hugging Face Hub

on:
  push:
    branches:
      - main
      - master

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 0   # full history for a faithful mirror
          lfs: true        # pull LFS objects so model weights sync too

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Push to Hugging Face Hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          HF_SPACE_NAME: ${{ secrets.HF_SPACE_NAME }}
        run: |
          git config --global user.email "github-actions[bot]@users.noreply.github.com"
          git config --global user.name "github-actions[bot]"

          # Install git-lfs
          git lfs install

          # Clone the HF space, or initialise a fresh repo if it does not
          # exist yet. NOTE: the token is embedded in the remote URL, so it
          # ends up in .git/config of the (ephemeral) runner workspace only.
          git clone https://user:$HF_TOKEN@huggingface.co/spaces/$HF_SPACE_NAME hf_space 2>/dev/null || {
            mkdir hf_space
            cd hf_space
            # -b main: make the branch name deterministic regardless of the
            # runner's init.defaultBranch (plain `git init` may create
            # `master`, which would break the push below).
            git init -b main
            git remote add origin https://user:$HF_TOKEN@huggingface.co/spaces/$HF_SPACE_NAME
            cd ..
          }

          cd hf_space

          # Configure git LFS tracking for model-weight formats
          git lfs install
          git lfs track "*.pt"
          git lfs track "*.pth"
          git lfs track "*.bin"
          git lfs track "*.h5"
          git lfs track "*.onnx"
          git lfs track "*.safetensors"

          # Copy files from the repository (excluding .git and hf_space)
          rsync -av --exclude='.git' --exclude='hf_space' ../ .

          # Commit only when something changed. On a freshly initialised repo
          # HEAD does not exist yet; diff-index then fails and the || branch
          # creates the initial commit, which is the intended behaviour.
          git add .
          git diff-index --quiet HEAD 2>/dev/null || git commit -m "Sync from GitHub: ${{ github.sha }}"

          # Push to the Space's `main` branch whatever the local branch is
          # called (HEAD:main is robust to the clone's default branch name).
          git push origin HEAD:main --force
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .Python
6
+ *.so
7
+ *.egg
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+ .env
12
+ .venv
13
+ venv/
14
+ ENV/
15
+ env/
16
+ .vscode/
17
+ .idea/
18
+ *.swp
19
+ *.swo
20
+ *~
21
+ .DS_Store
22
+ sample_output/
23
+ *.log
24
+ .pytest_cache/
25
+ .coverage
26
+ htmlcov/
27
+ .mypy_cache/
28
+ .ipynb_checkpoints/
29
+
30
+ *.md
31
+ !README_HF.md
32
+ !README.md
33
+ test*
34
+ executable.py
README.md CHANGED
@@ -48,7 +48,3 @@ print(response.json())
48
  curl -X POST "https://YOUR_USERNAME-invoice-extractor.hf.space/extract" \
49
  -F "file=@invoice.png"
50
  ```
51
-
52
- ## Hardware
53
-
54
- Requires GPU: T4 minimum (8GB VRAM recommended)
 
48
  curl -X POST "https://YOUR_USERNAME-invoice-extractor.hf.space/extract" \
49
  -F "file=@invoice.png"
50
  ```
 
 
 
 
app.py CHANGED
@@ -140,14 +140,18 @@ async def extract_invoice(
140
  )
141
 
142
  # Save uploaded file to temporary location
 
 
143
  temp_file = None
144
  try:
145
  # Create temporary file
 
146
  suffix = os.path.splitext(file.filename)[1]
147
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
148
  temp_file = temp.name
149
  # Write uploaded file content
150
  shutil.copyfileobj(file.file, temp)
 
151
 
152
  # Use filename as doc_id if not provided
153
  if doc_id is None:
@@ -156,6 +160,10 @@ async def extract_invoice(
156
  # Process invoice
157
  result = InferenceProcessor.process_invoice(temp_file, doc_id)
158
 
 
 
 
 
159
  return JSONResponse(content=result, media_type="application/json; charset=utf-8")
160
 
161
  except Exception as e:
 
140
  )
141
 
142
  # Save uploaded file to temporary location
143
+ import time
144
+ request_start = time.time()
145
  temp_file = None
146
  try:
147
  # Create temporary file
148
+ io_start = time.time()
149
  suffix = os.path.splitext(file.filename)[1]
150
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
151
  temp_file = temp.name
152
  # Write uploaded file content
153
  shutil.copyfileobj(file.file, temp)
154
+ io_time = round(time.time() - io_start, 3)
155
 
156
  # Use filename as doc_id if not provided
157
  if doc_id is None:
 
160
  # Process invoice
161
  result = InferenceProcessor.process_invoice(temp_file, doc_id)
162
 
163
+ # Add total request time (includes file I/O)
164
+ result['total_request_time_sec'] = round(time.time() - request_start, 2)
165
+ result['file_io_time_sec'] = io_time
166
+
167
  return JSONResponse(content=result, media_type="application/json; charset=utf-8")
168
 
169
  except Exception as e:
client_example.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example client script for Invoice Information Extractor API
3
+ Shows how to integrate the API into your application
4
+ """
5
+
6
+ import requests
7
+ from pathlib import Path
8
+ import json
9
+ from typing import List, Dict
10
+
11
+
12
+ class InvoiceExtractorClient:
13
+ """Client for Invoice Information Extractor API"""
14
+
15
+ def __init__(self, base_url: str = "http://localhost:7860"):
16
+ """
17
+ Initialize client
18
+
19
+ Args:
20
+ base_url: Base URL of the API (default: http://localhost:7860)
21
+ """
22
+ self.base_url = base_url.rstrip('/')
23
+ self.session = requests.Session()
24
+
25
+ def health_check(self) -> Dict:
26
+ """Check API health status"""
27
+ response = self.session.get(f"{self.base_url}/health")
28
+ response.raise_for_status()
29
+ return response.json()
30
+
31
+ def extract_invoice(self, image_path: str, doc_id: str = None) -> Dict:
32
+ """
33
+ Extract information from a single invoice
34
+
35
+ Args:
36
+ image_path: Path to invoice image
37
+ doc_id: Optional document identifier
38
+
39
+ Returns:
40
+ Extraction results as dictionary
41
+ """
42
+ with open(image_path, 'rb') as f:
43
+ files = {'file': f}
44
+ data = {'doc_id': doc_id} if doc_id else {}
45
+
46
+ response = self.session.post(
47
+ f"{self.base_url}/extract",
48
+ files=files,
49
+ data=data,
50
+ timeout=60
51
+ )
52
+ response.raise_for_status()
53
+ return response.json()
54
+
55
+ def extract_batch(self, image_paths: List[str]) -> List[Dict]:
56
+ """
57
+ Extract information from multiple invoices
58
+
59
+ Args:
60
+ image_paths: List of paths to invoice images
61
+
62
+ Returns:
63
+ List of extraction results
64
+ """
65
+ files = [('files', open(path, 'rb')) for path in image_paths]
66
+
67
+ try:
68
+ response = self.session.post(
69
+ f"{self.base_url}/extract_batch",
70
+ files=files,
71
+ timeout=120
72
+ )
73
+ response.raise_for_status()
74
+ return response.json()['results']
75
+ finally:
76
+ # Close all file handles
77
+ for _, file_handle in files:
78
+ file_handle.close()
79
+
80
+
81
+ # Example usage
82
+ if __name__ == "__main__":
83
+ # Initialize client
84
+ client = InvoiceExtractorClient("http://localhost:7860")
85
+
86
+ # Check health
87
+ print("Checking API health...")
88
+ health = client.health_check()
89
+ print(f"Status: {health['status']}")
90
+ print(f"Models loaded: {health['models_loaded']}\n")
91
+
92
+ # Example 1: Single invoice extraction
93
+ print("=" * 60)
94
+ print("Example 1: Single Invoice Extraction")
95
+ print("=" * 60)
96
+
97
+ # Replace with your invoice path
98
+ invoice_path = "sample_invoice.png"
99
+
100
+ if Path(invoice_path).exists():
101
+ try:
102
+ result = client.extract_invoice(invoice_path, doc_id="demo_001")
103
+
104
+ print(f"\nπŸ“„ Document ID: {result['doc_id']}")
105
+ print(f"βœ… Confidence: {result['confidence']}")
106
+ print(f"⏱️ Processing Time: {result['processing_time_sec']}s")
107
+ print(f"πŸ’° Cost Estimate: ${result['cost_estimate_usd']}")
108
+
109
+ print("\nπŸ“‹ Extracted Fields:")
110
+ fields = result['fields']
111
+ print(f" Dealer Name: {fields['dealer_name']}")
112
+ print(f" Model Name: {fields['model_name']}")
113
+ print(f" Horse Power: {fields['horse_power']} HP")
114
+ print(f" Asset Cost: β‚Ή{fields['asset_cost']:,}")
115
+ print(f" Signature: {'βœ“ Detected' if fields['signature']['present'] else 'βœ— Not found'}")
116
+ print(f" Stamp: {'βœ“ Detected' if fields['stamp']['present'] else 'βœ— Not found'}")
117
+
118
+ if result.get('warnings'):
119
+ print(f"\n⚠️ Warnings: {', '.join(result['warnings'])}")
120
+
121
+ except requests.exceptions.RequestException as e:
122
+ print(f"❌ Error: {e}")
123
+ else:
124
+ print(f"⚠️ Sample invoice not found at: {invoice_path}")
125
+ print(" Please provide a valid invoice image path.")
126
+
127
+ # Example 2: Batch processing
128
+ print("\n" + "=" * 60)
129
+ print("Example 2: Batch Invoice Processing")
130
+ print("=" * 60)
131
+
132
+ # Replace with your invoice paths
133
+ batch_paths = ["invoice_001.png", "invoice_002.png"]
134
+
135
+ existing_paths = [p for p in batch_paths if Path(p).exists()]
136
+
137
+ if existing_paths:
138
+ try:
139
+ results = client.extract_batch(existing_paths)
140
+
141
+ print(f"\nπŸ“¦ Processed {len(results)} invoices")
142
+
143
+ for i, result in enumerate(results, 1):
144
+ if 'error' in result:
145
+ print(f"\n {i}. ❌ {result.get('filename', 'Unknown')}: {result['error']}")
146
+ else:
147
+ print(f"\n {i}. βœ… {result['doc_id']}")
148
+ print(f" Confidence: {result['confidence']}")
149
+ print(f" Dealer: {result['fields']['dealer_name']}")
150
+ print(f" Cost: β‚Ή{result['fields']['asset_cost']:,}")
151
+
152
+ except requests.exceptions.RequestException as e:
153
+ print(f"❌ Error: {e}")
154
+ else:
155
+ print("⚠️ No valid invoice images found for batch processing")
156
+
157
+ # Example 3: Save results to JSON
158
+ print("\n" + "=" * 60)
159
+ print("Example 3: Save Results to JSON")
160
+ print("=" * 60)
161
+
162
+ if Path(invoice_path).exists():
163
+ try:
164
+ result = client.extract_invoice(invoice_path)
165
+
166
+ output_file = "extraction_result.json"
167
+ with open(output_file, 'w', encoding='utf-8') as f:
168
+ json.dump(result, f, indent=2, ensure_ascii=False)
169
+
170
+ print(f"\nβœ… Results saved to: {output_file}")
171
+
172
+ except Exception as e:
173
+ print(f"❌ Error: {e}")
174
+
175
+ print("\n" + "=" * 60)
176
+ print("Examples complete!")
177
+ print("=" * 60)
inference.py CHANGED
@@ -296,6 +296,7 @@ class InferenceProcessor:
296
  dict: Complete JSON output with all fields
297
  """
298
  total_start = time.time()
 
299
 
300
  # Generate doc_id if not provided
301
  if doc_id is None:
@@ -303,23 +304,33 @@ class InferenceProcessor:
303
  doc_id = os.path.splitext(os.path.basename(image_path))[0]
304
 
305
  # Step 1: Preprocess image
 
306
  image = InferenceProcessor.preprocess_image(image_path)
 
307
 
308
  # Step 2: YOLO Detection
 
309
  signature_info, stamp_info, signature_conf, stamp_conf = model_manager.detect_sign_stamp(image_path)
 
310
 
311
  # Step 3: VLM Extraction
 
312
  vlm_output, vlm_latency = InferenceProcessor.run_vlm_extraction(image)
 
313
 
314
  # Clean up image
315
  image.close()
316
  del image
317
 
318
  # Step 4: Parse JSON
 
319
  raw_json = InferenceProcessor.extract_json_from_output(vlm_output)
 
320
 
321
  # Step 5: Validate and fix
 
322
  validated_fields, field_confidence, warnings = InferenceProcessor.validate_prediction(raw_json)
 
323
 
324
  # Add signature and stamp
325
  validated_fields["signature"] = signature_info
@@ -344,6 +355,7 @@ class InferenceProcessor:
344
  "fields": validated_fields,
345
  "confidence": overall_confidence,
346
  "processing_time_sec": round(total_time, 2),
 
347
  "cost_estimate_usd": round(cost_estimate, 6),
348
  "warnings": warnings if warnings else None
349
  }
 
296
  dict: Complete JSON output with all fields
297
  """
298
  total_start = time.time()
299
+ timing_breakdown = {}
300
 
301
  # Generate doc_id if not provided
302
  if doc_id is None:
 
304
  doc_id = os.path.splitext(os.path.basename(image_path))[0]
305
 
306
  # Step 1: Preprocess image
307
+ t1 = time.time()
308
  image = InferenceProcessor.preprocess_image(image_path)
309
+ timing_breakdown['image_preprocessing'] = round(time.time() - t1, 3)
310
 
311
  # Step 2: YOLO Detection
312
+ t2 = time.time()
313
  signature_info, stamp_info, signature_conf, stamp_conf = model_manager.detect_sign_stamp(image_path)
314
+ timing_breakdown['yolo_detection'] = round(time.time() - t2, 3)
315
 
316
  # Step 3: VLM Extraction
317
+ t3 = time.time()
318
  vlm_output, vlm_latency = InferenceProcessor.run_vlm_extraction(image)
319
+ timing_breakdown['vlm_inference'] = round(vlm_latency, 3)
320
 
321
  # Clean up image
322
  image.close()
323
  del image
324
 
325
  # Step 4: Parse JSON
326
+ t4 = time.time()
327
  raw_json = InferenceProcessor.extract_json_from_output(vlm_output)
328
+ timing_breakdown['json_parsing'] = round(time.time() - t4, 3)
329
 
330
  # Step 5: Validate and fix
331
+ t5 = time.time()
332
  validated_fields, field_confidence, warnings = InferenceProcessor.validate_prediction(raw_json)
333
+ timing_breakdown['validation'] = round(time.time() - t5, 3)
334
 
335
  # Add signature and stamp
336
  validated_fields["signature"] = signature_info
 
355
  "fields": validated_fields,
356
  "confidence": overall_confidence,
357
  "processing_time_sec": round(total_time, 2),
358
+ "timing_breakdown": timing_breakdown,
359
  "cost_estimate_usd": round(cost_estimate, 6),
360
  "warnings": warnings if warnings else None
361
  }
model_manager.py CHANGED
@@ -1,145 +1,203 @@
1
- """
2
- Model Manager - Handles loading and caching of YOLO and VLM models
3
- """
4
-
5
- import torch
6
- from transformers import (
7
- Qwen2_5_VLForConditionalGeneration,
8
- AutoProcessor,
9
- BitsAndBytesConfig
10
- )
11
- from ultralytics import YOLO
12
- import os
13
- from typing import Tuple
14
-
15
- from config import (
16
- YOLO_MODEL_PATH,
17
- VLM_MODEL_ID,
18
- QUANTIZATION_CONFIG,
19
- YOLO_CONFIDENCE_THRESHOLD
20
- )
21
-
22
-
23
- class ModelManager:
24
- """Singleton class to manage model loading and inference"""
25
-
26
- _instance = None
27
- _initialized = False
28
-
29
- def __new__(cls):
30
- if cls._instance is None:
31
- cls._instance = super(ModelManager, cls).__new__(cls)
32
- return cls._instance
33
-
34
- def __init__(self):
35
- if not ModelManager._initialized:
36
- self.yolo_model = None
37
- self.vlm_model = None
38
- self.processor = None
39
- ModelManager._initialized = True
40
-
41
- def load_models(self):
42
- """Load both YOLO and VLM models into memory"""
43
- print("πŸš€ Starting model loading...")
44
-
45
- # Load YOLO model
46
- self.yolo_model = self._load_yolo_model()
47
-
48
- # Load VLM model
49
- self.vlm_model, self.processor = self._load_vlm_model()
50
-
51
- print("βœ… All models loaded successfully!")
52
-
53
- def _load_yolo_model(self) -> YOLO:
54
- """Load trained YOLO model for signature and stamp detection"""
55
- if not os.path.exists(YOLO_MODEL_PATH):
56
- raise FileNotFoundError(
57
- f"YOLO model not found at {YOLO_MODEL_PATH}. "
58
- "Please ensure best.pt is in utils/models/"
59
- )
60
-
61
- yolo_model = YOLO(str(YOLO_MODEL_PATH))
62
- print(f"βœ… YOLO model loaded from {YOLO_MODEL_PATH}")
63
- return yolo_model
64
-
65
- def _load_vlm_model(self) -> Tuple:
66
- """
67
- Load Qwen2.5-VL model with 4-bit quantization
68
- Downloads from Hugging Face on first run
69
- """
70
- print(f"πŸ“₯ Loading VLM model: {VLM_MODEL_ID}")
71
- print(" (This will download ~4GB on first run)")
72
-
73
- # Configure 4-bit quantization
74
- bnb_config = BitsAndBytesConfig(
75
- load_in_4bit=QUANTIZATION_CONFIG["load_in_4bit"],
76
- bnb_4bit_quant_type=QUANTIZATION_CONFIG["bnb_4bit_quant_type"],
77
- bnb_4bit_compute_dtype=getattr(torch, QUANTIZATION_CONFIG["bnb_4bit_compute_dtype"]),
78
- bnb_4bit_use_double_quant=QUANTIZATION_CONFIG["bnb_4bit_use_double_quant"]
79
- )
80
-
81
- # Load processor
82
- processor = AutoProcessor.from_pretrained(
83
- VLM_MODEL_ID,
84
- trust_remote_code=True
85
- )
86
-
87
- # Load model with quantization
88
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
89
- VLM_MODEL_ID,
90
- quantization_config=bnb_config,
91
- device_map="auto",
92
- torch_dtype=torch.bfloat16,
93
- trust_remote_code=True
94
- )
95
-
96
- model.eval()
97
- print(f"βœ… Qwen2.5-VL model loaded successfully")
98
-
99
- return model, processor
100
-
101
- def detect_sign_stamp(self, image_path: str):
102
- """
103
- Detect signature and stamp in the image using YOLO
104
-
105
- Returns:
106
- tuple: (signature_info, stamp_info, signature_conf, stamp_conf)
107
- """
108
- if self.yolo_model is None:
109
- raise RuntimeError("YOLO model not loaded. Call load_models() first.")
110
-
111
- results = self.yolo_model(image_path, verbose=False)[0]
112
-
113
- signature_info = {"present": False, "bbox": None}
114
- stamp_info = {"present": False, "bbox": None}
115
- signature_conf = 0.0
116
- stamp_conf = 0.0
117
-
118
- if results.boxes is not None:
119
- for box in results.boxes:
120
- cls_id = int(box.cls[0])
121
- conf = float(box.conf[0])
122
-
123
- if conf > YOLO_CONFIDENCE_THRESHOLD:
124
- bbox = box.xyxy[0].cpu().numpy().tolist()
125
- bbox = [int(coord) for coord in bbox]
126
-
127
- # Class 0: signature, Class 1: stamp
128
- if cls_id == 0 and conf > signature_conf:
129
- signature_info = {"present": True, "bbox": bbox}
130
- signature_conf = conf
131
- elif cls_id == 1 and conf > stamp_conf:
132
- stamp_info = {"present": True, "bbox": bbox}
133
- stamp_conf = conf
134
-
135
- return signature_info, stamp_info, signature_conf, stamp_conf
136
-
137
- def is_loaded(self) -> bool:
138
- """Check if models are loaded"""
139
- return (self.yolo_model is not None and
140
- self.vlm_model is not None and
141
- self.processor is not None)
142
-
143
-
144
- # Global model manager instance
145
- model_manager = ModelManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Manager - Handles loading and caching of YOLO and VLM models
3
+ """
4
+
5
+ import torch
6
+ from transformers import (
7
+ Qwen2_5_VLForConditionalGeneration,
8
+ AutoProcessor,
9
+ BitsAndBytesConfig
10
+ )
11
+ from ultralytics import YOLO
12
+ import os
13
+ from typing import Tuple
14
+
15
+ from config import (
16
+ YOLO_MODEL_PATH,
17
+ VLM_MODEL_ID,
18
+ QUANTIZATION_CONFIG,
19
+ YOLO_CONFIDENCE_THRESHOLD
20
+ )
21
+
22
+
23
+ class ModelManager:
24
+ """Singleton class to manage model loading and inference"""
25
+
26
+ _instance = None
27
+ _initialized = False
28
+
29
+ def __new__(cls):
30
+ if cls._instance is None:
31
+ cls._instance = super(ModelManager, cls).__new__(cls)
32
+ return cls._instance
33
+
34
+ def __init__(self):
35
+ if not ModelManager._initialized:
36
+ self.yolo_model = None
37
+ self.vlm_model = None
38
+ self.processor = None
39
+ ModelManager._initialized = True
40
+
41
+ def load_models(self):
42
+ """Load both YOLO and VLM models into memory"""
43
+ print("πŸš€ Starting model loading...")
44
+
45
+ # Load YOLO model
46
+ self.yolo_model = self._load_yolo_model()
47
+
48
+ # Load VLM model
49
+ self.vlm_model, self.processor = self._load_vlm_model()
50
+
51
+ # Warm up models to initialize CUDA context
52
+ self._warmup_models()
53
+
54
+ print("βœ… All models loaded successfully!")
55
+
56
+ def _load_yolo_model(self) -> YOLO:
57
+ """Load trained YOLO model for signature and stamp detection"""
58
+ if not os.path.exists(YOLO_MODEL_PATH):
59
+ raise FileNotFoundError(
60
+ f"YOLO model not found at {YOLO_MODEL_PATH}. "
61
+ "Please ensure best.pt is in utils/models/"
62
+ )
63
+
64
+ yolo_model = YOLO(str(YOLO_MODEL_PATH))
65
+ print(f"βœ… YOLO model loaded from {YOLO_MODEL_PATH}")
66
+ return yolo_model
67
+
68
+ def _load_vlm_model(self) -> Tuple:
69
+ """
70
+ Load Qwen2.5-VL model with 4-bit quantization
71
+ Downloads from Hugging Face on first run
72
+ """
73
+ print(f"πŸ“₯ Loading VLM model: {VLM_MODEL_ID}")
74
+ print(" (This will download ~4GB on first run)")
75
+
76
+ # Configure 4-bit quantization
77
+ bnb_config = BitsAndBytesConfig(
78
+ load_in_4bit=QUANTIZATION_CONFIG["load_in_4bit"],
79
+ bnb_4bit_quant_type=QUANTIZATION_CONFIG["bnb_4bit_quant_type"],
80
+ bnb_4bit_compute_dtype=getattr(torch, QUANTIZATION_CONFIG["bnb_4bit_compute_dtype"]),
81
+ bnb_4bit_use_double_quant=QUANTIZATION_CONFIG["bnb_4bit_use_double_quant"]
82
+ )
83
+
84
+ # Load processor
85
+ processor = AutoProcessor.from_pretrained(
86
+ VLM_MODEL_ID,
87
+ trust_remote_code=True
88
+ )
89
+
90
+ # Load model with quantization
91
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
92
+ VLM_MODEL_ID,
93
+ quantization_config=bnb_config,
94
+ device_map="auto",
95
+ torch_dtype=torch.bfloat16,
96
+ trust_remote_code=True
97
+ )
98
+
99
+ model.eval()
100
+ print(f"βœ… Qwen2.5-VL model loaded successfully")
101
+
102
+ return model, processor
103
+
104
+ def _warmup_models(self):
105
+ """Warm up models with a dummy inference to initialize CUDA context"""
106
+ print("πŸ”₯ Warming up models (initializing CUDA context)...")
107
+ import time
108
+ from PIL import Image
109
+ import numpy as np
110
+
111
+ warmup_start = time.time()
112
+
113
+ # Create a small dummy image
114
+ dummy_image = Image.fromarray(np.ones((100, 100, 3), dtype=np.uint8) * 255)
115
+
116
+ try:
117
+ # Warm up VLM
118
+ messages = [
119
+ {
120
+ "role": "user",
121
+ "content": [
122
+ {"type": "image", "image": dummy_image},
123
+ {"type": "text", "text": "warm up"}
124
+ ]
125
+ }
126
+ ]
127
+
128
+ from qwen_vl_utils import process_vision_info
129
+ text = self.processor.apply_chat_template(
130
+ messages,
131
+ tokenize=False,
132
+ add_generation_prompt=True
133
+ )
134
+ image_inputs, video_inputs = process_vision_info(messages)
135
+ inputs = self.processor(
136
+ text=[text],
137
+ images=image_inputs,
138
+ videos=video_inputs,
139
+ padding=True,
140
+ return_tensors="pt",
141
+ )
142
+ inputs = inputs.to("cuda")
143
+
144
+ # Run a quick inference
145
+ with torch.no_grad():
146
+ _ = self.vlm_model.generate(**inputs, max_new_tokens=5)
147
+
148
+ # Clean up
149
+ del inputs
150
+ if torch.cuda.is_available():
151
+ torch.cuda.empty_cache()
152
+
153
+ warmup_time = time.time() - warmup_start
154
+ print(f"βœ… Models warmed up in {warmup_time:.2f}s (CUDA context initialized)")
155
+
156
+ except Exception as e:
157
+ print(f"⚠️ Warmup failed (non-critical): {e}")
158
+
159
+ def detect_sign_stamp(self, image_path: str):
160
+ """
161
+ Detect signature and stamp in the image using YOLO
162
+
163
+ Returns:
164
+ tuple: (signature_info, stamp_info, signature_conf, stamp_conf)
165
+ """
166
+ if self.yolo_model is None:
167
+ raise RuntimeError("YOLO model not loaded. Call load_models() first.")
168
+
169
+ results = self.yolo_model(image_path, verbose=False)[0]
170
+
171
+ signature_info = {"present": False, "bbox": None}
172
+ stamp_info = {"present": False, "bbox": None}
173
+ signature_conf = 0.0
174
+ stamp_conf = 0.0
175
+
176
+ if results.boxes is not None:
177
+ for box in results.boxes:
178
+ cls_id = int(box.cls[0])
179
+ conf = float(box.conf[0])
180
+
181
+ if conf > YOLO_CONFIDENCE_THRESHOLD:
182
+ bbox = box.xyxy[0].cpu().numpy().tolist()
183
+ bbox = [int(coord) for coord in bbox]
184
+
185
+ # Class 0: signature, Class 1: stamp
186
+ if cls_id == 0 and conf > signature_conf:
187
+ signature_info = {"present": True, "bbox": bbox}
188
+ signature_conf = conf
189
+ elif cls_id == 1 and conf > stamp_conf:
190
+ stamp_info = {"present": True, "bbox": bbox}
191
+ stamp_conf = conf
192
+
193
+ return signature_info, stamp_info, signature_conf, stamp_conf
194
+
195
+ def is_loaded(self) -> bool:
196
+ """Check if models are loaded"""
197
+ return (self.yolo_model is not None and
198
+ self.vlm_model is not None and
199
+ self.processor is not None)
200
+
201
+
202
+ # Global model manager instance
203
+ model_manager = ModelManager()
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
- torch
2
- transformers
3
- ultralytics
4
- pillow
5
- accelerate
6
- bitsandbytes
7
- opencv-python
8
- pyyaml
9
- qwen-vl-utils[decord]
10
- fastapi
11
- uvicorn[standard]
12
  python-multipart
 
1
+ torch
2
+ transformers
3
+ ultralytics
4
+ pillow
5
+ accelerate
6
+ bitsandbytes
7
+ opencv-python
8
+ pyyaml
9
+ qwen-vl-utils[decord]
10
+ fastapi
11
+ uvicorn[standard]
12
  python-multipart