solomonodum committed on
Commit
bbc8b36
·
verified ·
1 Parent(s): 5e48f93

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. batch_api_2.py +381 -0
  2. nanonets_ocr_2.py +125 -0
  3. single_inferencing_2.py +310 -0
batch_api_2.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import base64
5
+ import asyncio
6
+ import concurrent.futures
7
+ from typing import Dict, Optional, List, Union
8
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel
11
+ import uvicorn
12
+ from PIL import Image
13
+ import io
14
+ from contextlib import asynccontextmanager
15
+ from prometheus_fastapi_instrumentator import Instrumentator
16
+
17
+ # Add the current directory to the path so we can import the llama_inferencing module
18
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
19
+
20
+ from single_inferencing_2 import SingleImageInference
21
+ from utils.prompt_utils import create_query, parse_label, create_query_updated
22
+ from utils.image_utils import encode_pil_image_to_base64
23
+
24
# --- GLOBAL VARS (Constants, not the inferencer itself) ---
# Directory where per-request inference logs are written.
LOG_DIR = os.getenv("LOG_DIR", "inference_logs")
# GPU index handed to the segmentation model inside SingleImageInference.
SEGMENTATION_DEVICE_ID = int(os.getenv("SEGMENTATION_DEVICE_ID", "7"))
# Env flag ("true"/"false", case-insensitive) intended to control bbox detection.
ENABLE_BBOX_DETECTION = os.getenv("ENABLE_BBOX_DETECTION", "False").lower() == "true"
# Set by the __main__ block (from --vllm-url) before uvicorn starts this app.
VLLM_SERVER_URL: Optional[str] = None
MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "10"))  # Maximum batch size
MAX_CONCURRENT_WORKERS = int(os.getenv("MAX_CONCURRENT_WORKERS", "4"))  # Concurrent processing limit
31
+
32
# --- Lifespan Context Manager ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown events for the FastAPI application.

    Startup: validates VLLM_SERVER_URL, then builds this worker's
    SingleImageInference instance and the thread pool used for batch
    fan-out. On any failure both resources are set to None so endpoints
    can report 500 instead of crashing the worker.
    Shutdown: drains the thread pool and closes the inferencer if it
    exposes a close() method.
    """
    global VLLM_SERVER_URL

    if VLLM_SERVER_URL is None:
        print("ERROR: VLLM_SERVER_URL was not set before lifespan start. Exiting.", flush=True)
        sys.exit(1)

    print(f"Lifespan: Initializing inferencer for this worker with VLLM URL: {VLLM_SERVER_URL}", flush=True)
    try:
        # FIX: honor the ENABLE_BBOX_DETECTION env flag. It was parsed at module
        # level but never used — this call hard-coded enable_bbox_detection=True.
        # Deployments that relied on the implicit True must now set
        # ENABLE_BBOX_DETECTION=true explicitly.
        app.state.inferencer = SingleImageInference(
            server_url=VLLM_SERVER_URL,
            log_dir=LOG_DIR,
            segmentation_device_id=SEGMENTATION_DEVICE_ID,
            enable_bbox_detection=ENABLE_BBOX_DETECTION
        )

        # Thread pool that process_batch_chunk() uses to run blocking
        # process_single_item() calls concurrently.
        app.state.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=MAX_CONCURRENT_WORKERS
        )

        print("Lifespan: Inferencer and thread pool successfully initialized.", flush=True)
    except Exception as e:
        print(f"Lifespan ERROR: Failed to initialize Inferencer: {e}", flush=True)
        app.state.inferencer = None
        app.state.thread_pool = None
    yield

    # Shutdown cleanup
    print("Lifespan: Application shutdown. Performing cleanup.", flush=True)
    if hasattr(app.state, 'thread_pool') and app.state.thread_pool:
        app.state.thread_pool.shutdown(wait=True)
    # hasattr(None, 'close') is False, so this is safe when init failed.
    if hasattr(app.state.inferencer, 'close'):
        app.state.inferencer.close()
72
+
73
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="Llama Inferencing API with Batch Processing",
    description="API for running inference on images using Llama model - supports both single and batch processing",
    lifespan=lifespan
)

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# fully open — tighten before exposing this service outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Instrument all routes and expose Prometheus metrics at /metrics.
Instrumentator().instrument(app).expose(app)
90
+
91
# --- BaseModel Definitions ---
class InferenceRequest(BaseModel):
    """Payload for /infer/ — a list of item dicts; only the first entry is used."""
    # Each item carries workorder/image metadata plus a base64-encoded "image".
    data: List[Dict[str, Union[str, float]]]
94
+
95
class BatchInferenceRequest(BaseModel):
    """Payload for /infer/batch/ — all entries are processed."""
    # Same item shape as InferenceRequest.data.
    data: List[Dict[str, Union[str, float]]]
    batch_size: Optional[int] = None  # Optional batch size override (defaults to MAX_BATCH_SIZE)
98
+
99
class InferenceResponse(BaseModel):
    """Envelope returned by /infer/."""
    body: Dict   # {"data": [formatted result]} — empty list on failure
    meta: Dict   # reserved; currently always {}
    error: str   # empty string on success, error message otherwise
103
+
104
class BatchInferenceResponse(BaseModel):
    """Envelope returned by /infer/batch/."""
    body: Dict        # {"data": [formatted results for successful items]}
    meta: Dict        # processing metadata
    error: str        # summary such as "N items failed", empty when all succeeded
    batch_info: Dict  # Additional batch processing info (counts, chunking)
109
+
110
def process_single_item(inferencer, item: Dict, temp_dir: str = "/tmp") -> Dict:
    """
    Process a single inference item - extracted for reuse in batch processing.

    Args:
        inferencer: Initialized SingleImageInference instance.
        item: Dict with required keys workorder_id, image_id, doc_type,
            business_type, workorder_type and a base64-encoded "image";
            optional keys task_name, format_name, temperature.
        temp_dir: Directory for the transient image file the query builder needs.

    Returns:
        {"success": bool, "result": formatted dict or None, "error": str or None}
    """
    temp_image_path = None
    try:
        # Extract fields from the item (a missing key is reported as a failure).
        workorder_id = item["workorder_id"]
        image_id = item["image_id"]
        doc_type = item["doc_type"]
        business_type = item["business_type"]
        workorder_type = item["workorder_type"]
        image_base64 = item["image"]

        # Decode the base64 image into a PIL image.
        image_content = base64.b64decode(image_base64)
        pil_image = Image.open(io.BytesIO(image_content))

        # Persist a temporary copy; create_query_updated works from a file path.
        temp_image_path = os.path.join(temp_dir, f"{image_id}_{workorder_id}.jpg")
        pil_image.save(temp_image_path)

        # Build the model query for this document type.
        query = create_query_updated(
            temp_image_path,
            doc_type.lower(),
            [item.get("task_name", "default")],
            [item.get("format_name", "reasoning_specrec")]
        )[0]

        query["image"] = pil_image
        query["doc_type"] = doc_type.upper()

        print(f"Processing WORKORDERID: {workorder_id}, DOCTYPE: {query['doc_type']}", flush=True)

        # Run inference using the initialized inferencer.
        inference_result = inferencer.run_inference(query, item.get("temperature", 0.1))

        # Parse the (possibly markdown-fenced) JSON response from the model.
        try:
            json_str = inference_result["response"].strip("`json\n")
            raw_response = json.loads(json_str)
        except Exception as e:
            print(f"Failed to parse model response: {e}. Raw response: {inference_result.get('response')}", flush=True)
            raw_response = {
                "reasoning": "Failed to parse model response",
                "evaluation_result": "UNKNOWN"
            }

        evaluation_result = raw_response.get("evaluation_result", "UNCERTAIN")

        # Normalize model_decision and map it onto a review-queue color.
        if evaluation_result == "VALID":
            model_decision = "VALID_INSTALL"
            review_queue = "GREEN"
        elif evaluation_result == "INVALID":
            model_decision = "INVALID_INSTALL"
            review_queue = "RED"
        else:
            model_decision = "UNCERTAIN"
            review_queue = "YELLOW"

        # Extract embedding from raw_response if available.
        embedding = raw_response.get("embedding")

        formatted_result = {
            "workorder_id": workorder_id,
            "image_id": image_id,
            "doc_type": doc_type,
            "business_type": business_type,
            "workorder_type": workorder_type,
            "confidence_threshold": 0,
            "model_output": {
                "model_decision_reason": raw_response.get("reasoning", ""),
                "model_decision": model_decision,
                "recommendation": raw_response.get("recommendations", ""),
                # "serial_id": raw_response.get("serial_id", ""),
                # NOTE(review): serial_id is a hard-coded placeholder — the
                # model-derived value (commented line above) is intentionally
                # bypassed; confirm before production use.
                "serial_id": "12345",
                "power_meter_reading": raw_response.get("power_meter_reading", ""),
                "review_queue": review_queue,
                "confidence_score": 0,
            }
        }

        # Add embedding to response if available.
        if embedding is not None:
            formatted_result["embedding"] = embedding

        return {"success": True, "result": formatted_result, "error": None}

    except Exception as e:
        print(f"Error processing item {item.get('workorder_id', 'unknown')}: {e}", flush=True)
        return {"success": False, "result": None, "error": str(e)}
    finally:
        # FIX: cleanup consolidated into finally so the temp file is removed on
        # every exit path (it was previously duplicated and could leak if the
        # success-path remove itself raised).
        if temp_image_path and os.path.exists(temp_image_path):
            try:
                os.remove(temp_image_path)
            except OSError:
                pass  # best-effort cleanup; never mask the real result
210
+
211
async def process_batch_chunk(inferencer, chunk: List[Dict], executor) -> List[Dict]:
    """
    Process a chunk of items concurrently using the given thread pool.

    Args:
        inferencer: Initialized SingleImageInference instance.
        chunk: Items (see process_single_item) to run in parallel.
        executor: concurrent.futures.Executor for the blocking calls.

    Returns:
        Per-item result dicts in the same order as `chunk`.
    """
    # FIX: use get_running_loop() — get_event_loop() is deprecated inside a
    # coroutine (Python 3.10+) and a loop is guaranteed to be running here.
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(executor, process_single_item, inferencer, item)
        for item in chunk
    ]
    return await asyncio.gather(*futures)
221
+
222
@app.post("/infer/", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """
    Run inference on a single image and return the results.

    Only the first entry of request.data is processed; use /infer/batch/
    for multiple items.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    # FIX: explicit empty-payload guard instead of letting request.data[0]
    # raise IndexError and surface as an opaque "list index out of range".
    if not request.data:
        return {
            "body": {"data": []},
            "meta": {},
            "error": "Request contained no items."
        }

    try:
        item = request.data[0]
        result = process_single_item(app.state.inferencer, item)

        if result["success"]:
            return {
                "body": {"data": [result["result"]]},
                "meta": {},
                "error": ""
            }
        else:
            return {
                "body": {"data": []},
                "meta": {},
                "error": result["error"]
            }
    except Exception as e:
        print(f"API - Error during inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e)
        }
253
+
254
@app.post("/infer/batch/", response_model=BatchInferenceResponse)
async def run_batch_inference(request: BatchInferenceRequest):
    """
    Run inference on multiple images in batches with concurrent processing.

    Items are split into chunks of `batch_size` (default MAX_BATCH_SIZE).
    Chunks are processed sequentially; items within a chunk run concurrently
    on the shared thread pool. Per-item failures are counted, not fatal.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    if app.state.thread_pool is None:
        raise HTTPException(status_code=500, detail="Thread pool not initialized.")

    try:
        batch_size = request.batch_size or MAX_BATCH_SIZE
        data = request.data

        # Validate batch size (allow up to 5x max batch size).
        if len(data) > MAX_BATCH_SIZE * 5:
            raise HTTPException(
                status_code=400,
                detail=f"Batch too large. Maximum allowed: {MAX_BATCH_SIZE * 5}, received: {len(data)}"
            )

        print(f"Processing batch of {len(data)} items with batch_size={batch_size}", flush=True)

        # Split data into chunks.
        chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

        all_results = []
        successful_count = 0
        failed_count = 0

        # Process chunks sequentially to avoid overwhelming the system.
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i + 1}/{len(chunks)} with {len(chunk)} items", flush=True)

            chunk_results = await process_batch_chunk(
                app.state.inferencer,
                chunk,
                app.state.thread_pool
            )

            # Collect results and count successes/failures.
            for result in chunk_results:
                if result["success"]:
                    all_results.append(result["result"])
                    successful_count += 1
                else:
                    failed_count += 1
                    print(f"Failed to process item: {result['error']}", flush=True)

        batch_info = {
            "total_items": len(data),
            "successful_items": successful_count,
            "failed_items": failed_count,
            "batch_size_used": batch_size,
            "total_chunks": len(chunks)
        }

        return {
            "body": {"data": all_results},
            "meta": {"processing_time": "completed"},
            "error": f"{failed_count} items failed" if failed_count > 0 else "",
            "batch_info": batch_info
        }

    except HTTPException:
        # FIX: re-raise HTTP errors. The 400 raised above for oversized batches
        # was previously swallowed by the generic handler below and returned as
        # a 200 response whose "error" field held the exception text.
        raise
    except Exception as e:
        print(f"API - Error during batch inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e),
            "batch_info": {"total_items": len(request.data), "successful_items": 0, "failed_items": len(request.data)}
        }
327
+
328
@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Responds 503 while either worker resource (inferencer / thread pool)
    is unavailable; otherwise reports the batch-processing limits.
    """
    inferencer_ready = app.state.inferencer is not None
    if not inferencer_ready:
        raise HTTPException(status_code=503, detail="Inferencer not initialized or failed to load")

    pool_ready = app.state.thread_pool is not None
    if not pool_ready:
        raise HTTPException(status_code=503, detail="Thread pool not initialized")

    return {
        "status": "healthy",
        "max_batch_size": MAX_BATCH_SIZE,
        "max_concurrent_workers": MAX_CONCURRENT_WORKERS,
    }
344
+
345
@app.get("/")
async def root():
    """Root endpoint for basic health check and endpoint discovery."""
    endpoint_map = {
        "single_inference": "/infer/",
        "batch_inference": "/infer/batch/",
        "health": "/health",
    }
    return {
        "status": "API is running",
        "service": "Llama Inferencing API with Batch Processing",
        "endpoints": endpoint_map,
    }
359
+
360
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8877, help="API port")
    parser.add_argument("--vllm-url", type=str, default="http://localhost:8000/v1", help="VLLM server URL")
    parser.add_argument("--max-batch-size", type=int, default=10, help="Maximum batch size")
    parser.add_argument("--max-workers", type=int, default=4, help="Maximum concurrent workers")

    args = parser.parse_args()

    # Store configuration globally. This works because uvicorn.run() below
    # serves the app in *this* process, so lifespan() reads the updated
    # globals. NOTE(review): if the app is ever launched via the uvicorn CLI
    # (separate import), these assignments never run and VLLM_SERVER_URL
    # stays None — lifespan() will exit the worker.
    VLLM_SERVER_URL = args.vllm_url
    MAX_BATCH_SIZE = args.max_batch_size
    MAX_CONCURRENT_WORKERS = args.max_workers

    print(f"Starting API server on port {args.port}", flush=True)
    print(f"VLLM URL: {args.vllm_url}", flush=True)
    print(f"Max batch size: {MAX_BATCH_SIZE}", flush=True)
    print(f"Max concurrent workers: {MAX_CONCURRENT_WORKERS}", flush=True)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False)
nanonets_ocr_2.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq, pipeline
4
+ from PIL import Image
5
+ import threading
6
+ from typing import List, Dict, Tuple
7
+
8
# Global pipeline instance (lazy singleton shared by all callers).
_pipeline_instance = None
# Resolved device string ("cuda:<id>" or "cpu"); set during initialization.
_device = None
# Guards one-time pipeline construction in initialize_nanonets_model().
_lock = threading.Lock()
12
+
13
def download_and_save_model(model_name="nanonets/Nanonets-OCR-s", local_dir="/app/models/nanonets-ocr"):
    """Download the OCR processor and model, persisting both under local_dir."""
    os.makedirs(local_dir, exist_ok=True)
    print(f"Downloading model to: {local_dir}")

    # Fetch both artifacts with trust_remote_code, then write them side by side.
    artifacts = (
        AutoProcessor.from_pretrained(model_name, trust_remote_code=True),
        AutoModelForVision2Seq.from_pretrained(model_name, trust_remote_code=True),
    )
    for artifact in artifacts:
        artifact.save_pretrained(local_dir)

    print(f"Model saved to: {local_dir}")
    return local_dir
28
+
29
def load_model(model_path, device_id=0):
    """
    Load the OCR pipeline from a local path, downloading the model first
    if no valid checkpoint is present.

    Args:
        model_path: Local directory expected to contain config.json.
        device_id: Device index forwarded to the transformers pipeline.

    Returns:
        An "image-text-to-text" transformers pipeline.
    """
    # A usable local checkpoint is identified by config.json at its root.
    if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "config.json")):
        print(f"Loading model from: {model_path} on device {device_id}")
    else:
        print("Local model not found, downloading...")
        download_and_save_model(local_dir=model_path)

    # FIX: single construction point — the identical pipeline(...) call was
    # previously duplicated in both branches.
    return pipeline("image-text-to-text", model=model_path, device=device_id, trust_remote_code=True)
39
+
40
def initialize_nanonets_model(device_id=0):
    """
    Thread-safe lazy initialization of the shared nanonets OCR pipeline.

    The first caller builds the pipeline under the module lock; subsequent
    callers get the cached instance.
    """
    global _pipeline_instance, _device

    with _lock:
        if _pipeline_instance is not None:
            return _pipeline_instance

        checkpoint_dir = "/app/models/nanonets-ocr"
        _device = "cpu" if not torch.cuda.is_available() else f"cuda:{device_id}"

        print(f"Loading Nanonets OCR model on {_device} (device_id: {device_id})")
        _pipeline_instance = load_model(checkpoint_dir, device_id)
        print("Nanonets OCR model initialized successfully")

        return _pipeline_instance
56
+
57
def extract_single_serial_number(pil_image: Image.Image) -> str:
    """
    Run the OCR pipeline on one image and return the serial-number string.

    Raises:
        RuntimeError: if initialize_nanonets_model() has not been called.
    """
    global _pipeline_instance

    if _pipeline_instance is None:
        raise RuntimeError("Model not initialized. Call initialize_nanonets_model() first.")

    prompt = "Identify the serial number that starts with IN. Strictly return ONLY the alphanumeric serial number string and nothing else."
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    try:
        pipeline_output = _pipeline_instance(chat)
        # The pipeline echoes the chat transcript; the final message holds
        # the model's reply.
        reply = pipeline_output[0]['generated_text'][-1]['content']
        return reply.strip() if reply else ""
    except Exception as e:
        print(f"Error extracting serial number: {e}")
        return ""
85
+
86
def extract_serial_numbers_batch(images_and_indices: List[Tuple[Image.Image, int]]) -> Dict[int, str]:
    """
    Extract serial numbers from multiple images in batch.

    Returns a dict mapping each original index to its serial number,
    with "" for images whose extraction failed.
    """
    if not images_and_indices:
        return {}

    print(f"Processing batch of {len(images_and_indices)} S2P_MFIELD images for serial extraction")

    extracted: Dict[int, str] = {}
    for image, original_index in images_and_indices:
        try:
            extracted[original_index] = extract_single_serial_number(image)
        except Exception as e:
            print(f"Error processing image at index {original_index}: {e}")
            extracted[original_index] = ""

    return extracted
108
+
109
def extract_serial_number(pil_image: Image.Image) -> str:
    """
    Single-image extraction entry point, kept for backward compatibility.

    Lazily initializes the shared pipeline on first use, then delegates
    to extract_single_serial_number().
    """
    global _pipeline_instance

    if _pipeline_instance is None:
        # Mirrors the old code path: build the model on demand.
        initialize_nanonets_model()

    return extract_single_serial_number(pil_image)
121
+
122
def cleanup_gpu_cache():
    """Return cached GPU memory to the allocator when CUDA is present; no-op otherwise."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
single_inferencing_2.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ import time
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+ import yaml
8
+ sys.path.append(os.getcwd())
9
+ import re
10
+ import pandas as pd
11
+ import base64
12
+
13
+ from openai import OpenAI
14
+ from PIL import Image
15
+ from sklearn.metrics import classification_report
16
+ from tqdm import tqdm
17
+ from utils.image_utils_yolo import YOLOProcessor
18
+ from utils.image_utils_bbox import updated_add_bbox
19
+ from utils.image_utils import add_bbox, encode_pil_image_to_base64
20
+ from utils.prompt_utils import create_query, parse_label
21
+ from embedding_service import embedding_service
22
+
23
+
24
def optimize_image_for_tokens(image, max_size=768):
    """
    Simple image optimization to reduce token consumption.

    Downscales the image so its longest side is at most max_size, keeping
    the aspect ratio; images already small enough are returned unchanged.

    Args:
        image: PIL Image object
        max_size: Maximum dimension (768 = good balance of quality vs tokens)

    Returns:
        Optimized PIL Image (the original object when no resize is needed).
    """
    original_size = image.size

    # Only resize if image is larger than max_size.
    if max(image.size) > max_size:
        # Scale both sides by the same ratio to preserve aspect ratio.
        ratio = max_size / max(image.size)
        new_size = (
            int(image.size[0] * ratio),
            int(image.size[1] * ratio)
        )

        # Resize with high-quality resampling.
        image = image.resize(new_size, Image.Resampling.LANCZOS)

        # FIX: report the actual reduction. Token cost scales with pixel area,
        # so reduction = 1 - (new/old)^2. The previous formula, (old/new)^2*100,
        # printed impossible values (e.g. "400% token reduction" for a 2x
        # downscale).
        reduction_pct = int((1 - (max_size / max(original_size)) ** 2) * 100)
        print(f"Image optimized: {original_size} → {new_size} (estimated {reduction_pct}% token reduction)", flush=True)

    return image
52
+
53
+
54
class SingleImageInference:
    """
    Single-image inference against a vLLM OpenAI-compatible server.

    Responsibilities: decoding/optimizing the input image, optional YOLO
    bounding-box overlay, CLIP embedding extraction, prompt dispatch to the
    chat-completions endpoint, and response post-processing (question
    mapping and reasoning formatting).
    """

    def __init__(
        self,
        server_url: str,
        segmentation_device_id: Optional[int] = None,
        log_dir: str = "inference_logs",
        enable_bbox_detection: bool = True,
    ):
        self.log_dir = log_dir
        self.enable_bbox_detection = enable_bbox_detection
        self.segmentation_device_id = segmentation_device_id

        # initialise client (vLLM ignores the API key, any value works)
        self.client = OpenAI(base_url=server_url, api_key="EMPTY")

        # Ensure log directory exists
        os.makedirs(log_dir, exist_ok=True)

        # Cache the available models to validate model existence
        self.available_models = self._get_available_models()

        # Load question mappings
        self.question_mappings = self._load_question_mappings()

        # Load document type groups
        self.doctype_groups = self._load_doctype_groups()

        # Load document - bbox class mapping
        self.doctype_detection_mapping = self._load_detection_class_mapping()

        self.yolo_processor = YOLOProcessor(device="cuda:0")

        # Initialize CLIP model for embeddings (best-effort; inference still
        # works without embeddings).
        try:
            embedding_service.load_model()
            print("✅ CLIP model loaded successfully for embeddings", flush=True)
        except Exception as e:
            print(f"⚠️ Warning: Failed to load CLIP model: {e}", flush=True)

    def _get_available_models(self) -> List[str]:
        """Fetch model ids from the server, retrying twice with linear backoff."""
        for attempt in range(3):
            try:
                response = self.client.models.list()
                return [model.id for model in response.data]
            except Exception as e:
                if attempt < 2:  # Don't wait after the last attempt
                    wait_time = (attempt + 1) * 2  # 2s, then 4s
                    print(f"Attempt {attempt + 1} failed. Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"Warning: Could not fetch available models: {str(e)}")
                    return []

    def _load_question_mappings(self) -> Dict[str, Dict[str, str]]:
        """Load question mappings from YAML file; {} on any failure."""
        try:
            mapping_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/question_mappings.yaml"
            with open(mapping_file, 'r') as f:
                return yaml.safe_load(f)
        except Exception as e:
            print(f"Error loading question mappings: {str(e)}")
            return {}

    def _load_doctype_groups(self) -> Dict[str, List[str]]:
        """Load document type groups from YAML file; {} on any failure."""
        try:
            groups_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/doctype_groups.yaml"
            with open(groups_file, 'r') as f:
                return yaml.safe_load(f)
        except Exception as e:
            print(f"Error loading document type groups: {str(e)}")
            return {}

    def _get_parent_doctype(self, doc_type: str) -> str:
        """Return the parent document type, or the input itself if unmapped."""
        # Ensure doc_type is uppercase for consistency
        doc_type = doc_type.upper()

        # Search through the groups to find the parent
        for parent_type, child_types in self.doctype_groups.items():
            if doc_type in child_types:
                return parent_type

        # If no parent found, return the original type
        return doc_type

    def _load_detection_class_mapping(self) -> Dict[str, List[List[str]]]:
        """Load and cache detection class mapping from YAML file; {} on failure."""
        try:
            mapping_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/doctype_bbox_detection_mapping.yaml"
            with open(mapping_file, "r") as f:
                return yaml.safe_load(f)
        except Exception as e:
            print(f"Error loading detection class mapping: {str(e)}")
            return {}

    def _get_detection_classes_for_doctype(self, parent_doctype: str) -> Optional[List[str]]:
        """Return the flattened list of detection classes for a parent doctype, or None."""
        parent_doctype = parent_doctype.upper()
        class_groups = self.doctype_detection_mapping.get(parent_doctype)

        if not class_groups:
            return None

        # Flatten [[a, b], [c]] -> [a, b, c]
        return [cls for group in class_groups for cls in group]

    def _map_question(self, question: str, doc_type: str = "default") -> str:
        """Map complex questions to simpler versions based on document type."""
        # Ensure doc_type is uppercase for consistency
        doc_type = doc_type.upper()

        # Map the document type to its parent type
        parent_doc_type = self._get_parent_doctype(doc_type)

        # Get mappings for the specific document type, fallback to default if not found
        mappings = self.question_mappings.get(parent_doc_type, self.question_mappings.get("default", {}))
        print(f"Mapping question for doc_type: {doc_type} (parent: {parent_doc_type})")

        # Clean up the question by removing any numbering prefix
        clean_question = re.sub(r'^\d+\.\s*', '', question)

        # Try to find an exact match in the mappings
        for complex_q, simple_q in mappings.items():
            # Remove numbering from complex question for comparison
            clean_complex_q = re.sub(r'^\d+\.\s*', '', complex_q)
            if clean_question.startswith(clean_complex_q):
                return simple_q

        # If no match found, return the original question
        return question

    def _extract_reasoning(self, raw_response: Dict, doc_type: str = "default") -> str:
        """Extract and format reasoning from raw response into a single string."""
        # Map the document type to its parent type
        parent_doc_type = self._get_parent_doctype(doc_type)

        reasoning = raw_response.get("reasoning", [])
        if isinstance(reasoning, list):
            formatted_reasoning = []
            for i, item in enumerate(reasoning, 1):
                if ": " in item:
                    question, answer = item.split(": ", 1)
                    mapped_question = self._map_question(question, parent_doc_type)
                    formatted_reasoning.append(f"{i}. {mapped_question}: {answer}")
                else:
                    formatted_reasoning.append(f"{i}. {item}")
            return "\n".join(formatted_reasoning)
        return str(reasoning)

    def run_inference(self, query, temperature: float) -> Dict[str, str]:
        """
        Run inference on a single image using the vLLM server.

        Args:
            query: Dict with "image" (PIL Image or base64 string), "doc_type",
                "task_instruction" and "format_instruction".
            temperature: Sampling temperature for the completion.

        Returns:
            {"response": JSON string, "label": classification, "success": bool}

        Raises:
            ValueError: if the image payload cannot be decoded.
        """
        doc_type = query.get("doc_type", "default").upper()
        parent_doc_type = self._get_parent_doctype(doc_type)
        print(f"<-- parent_doc_type : {parent_doc_type} -->")

        try:
            if isinstance(query["image"], Image.Image):
                image = query["image"].convert("RGB")
            else:
                image_data = base64.b64decode(query["image"])
                # FIX: BytesIO was referenced here without ever being imported
                # in this module, so the base64 path raised NameError; import
                # it locally.
                from io import BytesIO
                image = Image.open(BytesIO(image_data)).convert("RGB")
        except Exception as e:
            print(f"Error decoding base64 image: {str(e)}")
            raise ValueError("Invalid image format in query")

        # Extract embedding from original image before any bbox processing
        try:
            embedding = embedding_service.extract_embedding(image)
            print("✅ Embedding extracted successfully", flush=True)
        except Exception as e:
            print(f"⚠️ Warning: Failed to extract embedding: {e}", flush=True)
            embedding = None

        image = optimize_image_for_tokens(image, max_size=768)
        # Optional bbox detection overlay (best-effort).
        if self.enable_bbox_detection:
            try:
                detection_classes = self._get_detection_classes_for_doctype(parent_doc_type)
                print(f"<< DETECTION CLASS : {detection_classes} >>", flush=True)

                if detection_classes:
                    image_with_boxes = self.yolo_processor.process_bbox(
                        image, desired_classes=detection_classes
                    )
                    if image_with_boxes is not None:
                        image = image_with_boxes
                else:
                    print(f"No detection class mapping found for parent doctype: {parent_doc_type}", flush=True)
            except Exception as e:
                print(f"Error applying bounding boxes: {str(e)}", flush=True)

        # Build the multimodal chat request.
        image_b64 = encode_pil_image_to_base64(image)
        instruction = f"{query['task_instruction']}\n\n{query['format_instruction']}"

        # Refresh the cached model list if empty (server may have come up late).
        if not self.available_models:
            self.available_models = self._get_available_models()
            if not self.available_models:
                raise ValueError(
                    "No models available on the server. Please ensure the vLLM server is running and accessible."
                )
        model_id = self.available_models[0]

        response = self.client.chat.completions.create(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": instruction},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                        },
                    ],
                }
            ],
            max_tokens=512,
            temperature=temperature,
            top_p=0.95,
        )

        raw_response = response.choices[0].message.content

        try:
            classification_label = parse_label(raw_response)
            success = True
        except AttributeError:
            classification_label = "UNKNOWN"
            success = False

        # Strip a possible ```json fence, re-format reasoning, attach embedding.
        try:
            json_str = raw_response.strip("`json\n")
            raw_response_dict = json.loads(json_str)
            formatted_reasoning = self._extract_reasoning(raw_response_dict, doc_type)
            raw_response_dict["reasoning"] = formatted_reasoning

            # Add embedding to response if available
            if embedding is not None:
                raw_response_dict["embedding"] = embedding

            raw_response = json.dumps(raw_response_dict)
        except Exception as e:
            print(f"Error formatting reasoning: {str(e)}")
            response_dict = {
                "reasoning": str(raw_response),
                "evaluation_result": classification_label
            }
            if embedding is not None:
                response_dict["embedding"] = embedding
            raw_response = json.dumps(response_dict)

        return {
            "response": raw_response,
            "label": classification_label,
            "success": success,
        }