Upload folder using huggingface_hub
Browse files- batch_api_2.py +381 -0
- nanonets_ocr_2.py +125 -0
- single_inferencing_2.py +310 -0
batch_api_2.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import base64
|
| 5 |
+
import asyncio
|
| 6 |
+
import concurrent.futures
|
| 7 |
+
from typing import Dict, Optional, List, Union
|
| 8 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from pydantic import BaseModel
|
| 11 |
+
import uvicorn
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import io
|
| 14 |
+
from contextlib import asynccontextmanager
|
| 15 |
+
from prometheus_fastapi_instrumentator import Instrumentator
|
| 16 |
+
|
| 17 |
+
# Add the current directory to the path so we can import the llama_inferencing module
|
| 18 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 19 |
+
|
| 20 |
+
from single_inferencing_2 import SingleImageInference
|
| 21 |
+
from utils.prompt_utils import create_query, parse_label, create_query_updated
|
| 22 |
+
from utils.image_utils import encode_pil_image_to_base64
|
| 23 |
+
|
| 24 |
+
# --- GLOBAL VARS (Constants, not the inferencer itself) ---
|
| 25 |
+
LOG_DIR = os.getenv("LOG_DIR", "inference_logs")
|
| 26 |
+
SEGMENTATION_DEVICE_ID = int(os.getenv("SEGMENTATION_DEVICE_ID", "7"))
|
| 27 |
+
ENABLE_BBOX_DETECTION = os.getenv("ENABLE_BBOX_DETECTION", "False").lower() == "true"
|
| 28 |
+
VLLM_SERVER_URL: Optional[str] = None
|
| 29 |
+
MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "10")) # Maximum batch size
|
| 30 |
+
MAX_CONCURRENT_WORKERS = int(os.getenv("MAX_CONCURRENT_WORKERS", "4")) # Concurrent processing limit
|
| 31 |
+
|
| 32 |
+
# --- Lifespan Context Manager ---
|
| 33 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.

    Startup: builds the SingleImageInference client and the thread pool used
    for batch processing, storing both on ``app.state``. Shutdown: drains the
    thread pool and closes the inferencer if it exposes a ``close`` method.
    """
    global VLLM_SERVER_URL

    if VLLM_SERVER_URL is None:
        print("ERROR: VLLM_SERVER_URL was not set before lifespan start. Exiting.", flush=True)
        sys.exit(1)

    print(f"Lifespan: Initializing inferencer for this worker with VLLM URL: {VLLM_SERVER_URL}", flush=True)
    try:
        app.state.inferencer = SingleImageInference(
            server_url=VLLM_SERVER_URL,
            log_dir=LOG_DIR,
            segmentation_device_id=SEGMENTATION_DEVICE_ID,
            # BUG FIX: was hardcoded to True, silently ignoring the
            # ENABLE_BBOX_DETECTION environment variable defined above.
            enable_bbox_detection=ENABLE_BBOX_DETECTION
        )

        # Thread pool used by /infer/batch/ to run items concurrently.
        app.state.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=MAX_CONCURRENT_WORKERS
        )

        print("Lifespan: Inferencer and thread pool successfully initialized.", flush=True)
    except Exception as e:
        # Keep the app alive so /health can report the failure instead of crashing.
        print(f"Lifespan ERROR: Failed to initialize Inferencer: {e}", flush=True)
        app.state.inferencer = None
        app.state.thread_pool = None
    yield

    # Shutdown cleanup
    print("Lifespan: Application shutdown. Performing cleanup.", flush=True)
    if getattr(app.state, 'thread_pool', None):
        app.state.thread_pool.shutdown(wait=True)
    # BUG FIX: use getattr — if startup crashed before assigning the attribute,
    # touching app.state.inferencer directly would raise AttributeError here.
    inferencer = getattr(app.state, 'inferencer', None)
    if hasattr(inferencer, 'close'):
        inferencer.close()
|
| 72 |
+
|
| 73 |
+
# Initialize FastAPI app with lifespan (startup/shutdown handled in lifespan()).
app = FastAPI(
    title="Llama Inferencing API with Batch Processing",
    description="API for running inference on images using Llama model - supports both single and batch processing",
    lifespan=lifespan
)

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is fully
# open; confirm this service is only reachable on a trusted internal network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose Prometheus metrics for this app (default endpoint: /metrics).
Instrumentator().instrument(app).expose(app)
|
| 90 |
+
|
| 91 |
+
# --- BaseModel Definitions ---
class InferenceRequest(BaseModel):
    # List of work items; each item is a flat dict of request fields
    # (workorder_id, image_id, doc_type, base64 "image", optional temperature, ...).
    # Only the first item is processed by /infer/.
    data: List[Dict[str, Union[str, float]]]

class BatchInferenceRequest(BaseModel):
    # Same item shape as InferenceRequest.data; all items are processed.
    data: List[Dict[str, Union[str, float]]]
    batch_size: Optional[int] = None  # Optional batch size override (defaults to MAX_BATCH_SIZE)

class InferenceResponse(BaseModel):
    body: Dict   # {"data": [formatted results]}
    meta: Dict   # reserved for metadata; currently empty
    error: str   # empty string on success, error message otherwise

class BatchInferenceResponse(BaseModel):
    body: Dict   # {"data": [formatted results for successful items]}
    meta: Dict
    error: str
    batch_info: Dict  # Additional batch processing info (counts, chunking)
|
| 109 |
+
|
| 110 |
+
def process_single_item(inferencer, item: Dict, temp_dir: str = "/tmp") -> Dict:
    """
    Process a single inference item - extracted for reuse in batch processing.

    Args:
        inferencer: object exposing ``run_inference(query, temperature)``.
        item: request dict with required keys ``workorder_id``, ``image_id``,
            ``doc_type``, ``business_type``, ``workorder_type`` and ``image``
            (base64-encoded), plus optional ``task_name``, ``format_name``
            and ``temperature``.
        temp_dir: directory for the transient image file passed to the
            query builder.

    Returns:
        ``{"success": bool, "result": dict | None, "error": str | None}``.
    """
    temp_image_path = None
    try:
        # Extract required fields from the item (KeyError here is reported
        # through the generic failure path below).
        workorder_id = item["workorder_id"]
        image_id = item["image_id"]
        doc_type = item["doc_type"]
        business_type = item["business_type"]
        workorder_type = item["workorder_type"]
        image_base64 = item["image"]

        # Decode the base64 image
        image_content = base64.b64decode(image_base64)
        pil_image = Image.open(io.BytesIO(image_content))

        # Persist to a temporary file because create_query_updated takes a path.
        temp_image_path = os.path.join(temp_dir, f"{image_id}_{workorder_id}.jpg")
        pil_image.save(temp_image_path)

        # Create query for the image
        query = create_query_updated(
            temp_image_path,
            doc_type.lower(),
            [item.get("task_name", "default")],
            [item.get("format_name", "reasoning_specrec")]
        )[0]

        query["image"] = pil_image
        query["doc_type"] = doc_type.upper()

        print(f"Processing WORKORDERID: {workorder_id}, DOCTYPE: {query['doc_type']}", flush=True)

        # Run inference using the initialized inferencer
        inference_result = inferencer.run_inference(query, item.get("temperature", 0.1))

        # Parse the response; strip() removes surrounding ```json fences.
        try:
            json_str = inference_result["response"].strip("`json\n")
            raw_response = json.loads(json_str)
        except Exception as e:
            print(f"Failed to parse model response: {e}. Raw response: {inference_result.get('response')}", flush=True)
            raw_response = {
                "reasoning": "Failed to parse model response",
                "evaluation_result": "UNKNOWN"
            }

        evaluation_result = raw_response.get("evaluation_result", "UNCERTAIN")

        # Normalize model_decision and map it to a review queue color.
        if evaluation_result == "VALID":
            model_decision = "VALID_INSTALL"
            review_queue = "GREEN"
        elif evaluation_result == "INVALID":
            model_decision = "INVALID_INSTALL"
            review_queue = "RED"
        else:
            model_decision = "UNCERTAIN"
            review_queue = "YELLOW"

        # Extract embedding from raw_response if available
        embedding = raw_response.get("embedding")

        formatted_result = {
            "workorder_id": workorder_id,
            "image_id": image_id,
            "doc_type": doc_type,
            "business_type": business_type,
            "workorder_type": workorder_type,
            "confidence_threshold": 0,
            "model_output": {
                "model_decision_reason": raw_response.get("reasoning", ""),
                "model_decision": model_decision,
                "recommendation": raw_response.get("recommendations", ""),
                # NOTE(review): real serial extraction is disabled; a fixed
                # placeholder is returned instead of raw_response["serial_id"].
                "serial_id": "12345",
                "power_meter_reading": raw_response.get("power_meter_reading", ""),
                "review_queue": review_queue,
                "confidence_score": 0,
            }
        }

        # Add embedding to response if available
        if embedding is not None:
            formatted_result["embedding"] = embedding

        return {"success": True, "result": formatted_result, "error": None}

    except Exception as e:
        print(f"Error processing item {item.get('workorder_id', 'unknown')}: {e}", flush=True)
        return {"success": False, "result": None, "error": str(e)}
    finally:
        # IMPROVEMENT: single cleanup path replaces the duplicated
        # `'temp_image_path' in locals()` removal on success and error.
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
|
| 210 |
+
|
| 211 |
+
async def process_batch_chunk(inferencer, chunk: List[Dict], executor) -> List[Dict]:
    """
    Process a chunk of items concurrently using the given thread pool.

    Returns the per-item dicts from process_single_item, in input order
    (asyncio.gather preserves ordering).
    """
    # IMPROVEMENT: get_running_loop() is the correct call from inside a
    # coroutine; get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(executor, process_single_item, inferencer, item)
        for item in chunk
    ]
    return await asyncio.gather(*futures)
|
| 221 |
+
|
| 222 |
+
@app.post("/infer/", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """
    Run inference on a single image and return the results.

    Only the first element of ``request.data`` is processed; extra items
    are ignored (use /infer/batch/ for multiple items).
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    # BUG FIX: an empty payload previously surfaced as a confusing
    # "list index out of range" error string with HTTP 200; reject it up front.
    if not request.data:
        raise HTTPException(status_code=400, detail="Request data must contain at least one item.")

    try:
        item = request.data[0]
        result = process_single_item(app.state.inferencer, item)

        if result["success"]:
            return {
                "body": {"data": [result["result"]]},
                "meta": {},
                "error": ""
            }
        else:
            return {
                "body": {"data": []},
                "meta": {},
                "error": result["error"]
            }
    except Exception as e:
        print(f"API - Error during inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e)
        }
|
| 253 |
+
|
| 254 |
+
@app.post("/infer/batch/", response_model=BatchInferenceResponse)
async def run_batch_inference(request: BatchInferenceRequest):
    """
    Run inference on multiple images in batches with concurrent processing.

    Items are split into chunks of at most ``batch_size``; chunks run
    sequentially while items within a chunk run concurrently on the app's
    thread pool. Failed items are omitted from ``body.data`` and counted
    in ``batch_info``.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    if app.state.thread_pool is None:
        raise HTTPException(status_code=500, detail="Thread pool not initialized.")

    try:
        batch_size = request.batch_size or MAX_BATCH_SIZE
        data = request.data

        # Validate batch size (allow up to 5x max batch size)
        if len(data) > MAX_BATCH_SIZE * 5:
            raise HTTPException(
                status_code=400,
                detail=f"Batch too large. Maximum allowed: {MAX_BATCH_SIZE * 5}, received: {len(data)}"
            )

        print(f"Processing batch of {len(data)} items with batch_size={batch_size}", flush=True)

        # Split data into chunks
        chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

        all_results = []
        successful_count = 0
        failed_count = 0

        # Process chunks sequentially to avoid overwhelming the system
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i + 1}/{len(chunks)} with {len(chunk)} items", flush=True)

            chunk_results = await process_batch_chunk(
                app.state.inferencer,
                chunk,
                app.state.thread_pool
            )

            # Collect results and count successes/failures
            for result in chunk_results:
                if result["success"]:
                    all_results.append(result["result"])
                    successful_count += 1
                else:
                    failed_count += 1
                    print(f"Failed to process item: {result['error']}", flush=True)

        batch_info = {
            "total_items": len(data),
            "successful_items": successful_count,
            "failed_items": failed_count,
            "batch_size_used": batch_size,
            "total_chunks": len(chunks)
        }

        return {
            "body": {"data": all_results},
            "meta": {"processing_time": "completed"},
            "error": f"{failed_count} items failed" if failed_count > 0 else "",
            "batch_info": batch_info
        }

    except HTTPException:
        # BUG FIX: HTTPException is a subclass of Exception, so the generic
        # handler below used to swallow the intended 400 "batch too large"
        # response and return HTTP 200 with an error string. Re-raise so
        # FastAPI emits the real status code.
        raise
    except Exception as e:
        print(f"API - Error during batch inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e),
            "batch_info": {"total_items": len(request.data), "successful_items": 0, "failed_items": len(request.data)}
        }
|
| 327 |
+
|
| 328 |
+
@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Returns 503 if either startup resource is missing; otherwise reports
    the batching configuration currently in effect.
    """
    required = (
        ("Inferencer not initialized or failed to load", app.state.inferencer),
        ("Thread pool not initialized", app.state.thread_pool),
    )
    for detail, resource in required:
        if resource is None:
            raise HTTPException(status_code=503, detail=detail)

    return {
        "status": "healthy",
        "max_batch_size": MAX_BATCH_SIZE,
        "max_concurrent_workers": MAX_CONCURRENT_WORKERS
    }
|
| 344 |
+
|
| 345 |
+
@app.get("/")
async def root():
    """
    Root endpoint for basic health check.

    Always succeeds; lists the service's other endpoints.
    """
    endpoint_map = {
        "single_inference": "/infer/",
        "batch_inference": "/infer/batch/",
        "health": "/health"
    }
    payload = {
        "status": "API is running",
        "service": "Llama Inferencing API with Batch Processing",
        "endpoints": endpoint_map,
    }
    return payload
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8877, help="API port")
    parser.add_argument("--vllm-url", type=str, default="http://localhost:8000/v1", help="VLLM server URL")
    parser.add_argument("--max-batch-size", type=int, default=10, help="Maximum batch size")
    parser.add_argument("--max-workers", type=int, default=4, help="Maximum concurrent workers")

    args = parser.parse_args()

    # Store configuration globally. These module-level rebinds are visible to
    # lifespan() and the endpoints because uvicorn.run(app, ...) serves this
    # same module in-process.
    # NOTE(review): if launched via `uvicorn module:app` instead, this block
    # never runs and only the env-var defaults apply — confirm deploy method.
    VLLM_SERVER_URL = args.vllm_url
    MAX_BATCH_SIZE = args.max_batch_size
    MAX_CONCURRENT_WORKERS = args.max_workers

    print(f"Starting API server on port {args.port}", flush=True)
    print(f"VLLM URL: {args.vllm_url}", flush=True)
    print(f"Max batch size: {MAX_BATCH_SIZE}", flush=True)
    print(f"Max concurrent workers: {MAX_CONCURRENT_WORKERS}", flush=True)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False)
|
nanonets_ocr_2.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoProcessor, AutoModelForVision2Seq, pipeline
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import threading
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
|
| 8 |
+
# Global pipeline instance
|
| 9 |
+
_pipeline_instance = None
|
| 10 |
+
_device = None
|
| 11 |
+
_lock = threading.Lock()
|
| 12 |
+
|
| 13 |
+
def download_and_save_model(model_name="nanonets/Nanonets-OCR-s", local_dir="/app/models/nanonets-ocr"):
    """Download the Nanonets OCR processor and model, persist both locally.

    Returns the local directory the artifacts were written to.
    """
    os.makedirs(local_dir, exist_ok=True)
    print(f"Downloading model to: {local_dir}")

    # Fetch both artifacts first (trust_remote_code applied consistently),
    # then persist them in the same order.
    artifacts = (
        AutoProcessor.from_pretrained(model_name, trust_remote_code=True),
        AutoModelForVision2Seq.from_pretrained(model_name, trust_remote_code=True),
    )
    for artifact in artifacts:
        artifact.save_pretrained(local_dir)

    print(f"Model saved to: {local_dir}")
    return local_dir
|
| 28 |
+
|
| 29 |
+
def load_model(model_path, device_id=0):
    """Build the OCR pipeline from a local checkpoint, downloading it first if absent."""
    config_present = os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "config.json"))
    if config_present:
        print(f"Loading model from: {model_path} on device {device_id}")
    else:
        print("Local model not found, downloading...")
        download_and_save_model(local_dir=model_path)
    # Both branches construct the identical pipeline once the files exist.
    return pipeline("image-text-to-text", model=model_path, device=device_id, trust_remote_code=True)
|
| 39 |
+
|
| 40 |
+
def initialize_nanonets_model(device_id=0):
    """Thread-safe lazy initializer for the shared Nanonets OCR pipeline.

    Subsequent calls return the already-built pipeline; device_id only
    matters on the first call.
    """
    global _pipeline_instance, _device

    with _lock:
        if _pipeline_instance is not None:
            return _pipeline_instance

        local_model_path = "/app/models/nanonets-ocr"
        _device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
        print(f"Loading Nanonets OCR model on {_device} (device_id: {device_id})")
        _pipeline_instance = load_model(local_model_path, device_id)
        print("Nanonets OCR model initialized successfully")

    return _pipeline_instance
|
| 56 |
+
|
| 57 |
+
def extract_single_serial_number(pil_image: Image.Image) -> str:
    """Extract serial number from a single image - based on old working code.

    Requires initialize_nanonets_model() to have been called first.
    Returns the stripped model output, or "" on any failure.
    """
    global _pipeline_instance

    if _pipeline_instance is None:
        raise RuntimeError("Model not initialized. Call initialize_nanonets_model() first.")

    try:
        # Chat-style message with the image inline; prompt constrains the model
        # to return only the alphanumeric serial string.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_image},
                    {"type": "text", "text": "Identify the serial number that starts with IN. Strictly return ONLY the alphanumeric serial number string and nothing else."}
                ]
            }
        ]

        # Pipeline output indexing: first result -> generated_text turn list ->
        # last turn's content (assumes the transformers image-text-to-text
        # pipeline output schema — TODO confirm on library upgrades).
        result = _pipeline_instance(messages)
        content = result[0]['generated_text'][-1]['content']

        return content.strip() if content else ""

    except Exception as e:
        # Best-effort: swallow inference errors and signal failure with "".
        print(f"Error extracting serial number: {e}")
        return ""
|
| 85 |
+
|
| 86 |
+
def extract_serial_numbers_batch(images_and_indices: List[Tuple[Image.Image, int]]) -> Dict[int, str]:
    """
    Extract serial numbers from multiple images in batch.

    Each entry pairs an image with its original position; the returned dict
    maps that position to the extracted serial ("" on per-image failure).
    """
    if not images_and_indices:
        return {}

    print(f"Processing batch of {len(images_and_indices)} S2P_MFIELD images for serial extraction")

    extracted: Dict[int, str] = {}
    for image, original_index in images_and_indices:
        try:
            extracted[original_index] = extract_single_serial_number(image)
        except Exception as e:
            print(f"Error processing image at index {original_index}: {e}")
            extracted[original_index] = ""

    return extracted
|
| 108 |
+
|
| 109 |
+
def extract_serial_number(pil_image: Image.Image) -> str:
    """
    Single-image extraction wrapper kept for backward compatibility.

    Lazily initializes the shared pipeline on first use, then delegates
    to extract_single_serial_number().
    """
    if _pipeline_instance is None:
        initialize_nanonets_model()
    return extract_single_serial_number(pil_image)
|
| 121 |
+
|
| 122 |
+
def cleanup_gpu_cache():
    """Release cached GPU memory back to the allocator; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
single_inferencing_2.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 7 |
+
import yaml
|
| 8 |
+
sys.path.append(os.getcwd())
|
| 9 |
+
import re
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import base64
|
| 12 |
+
|
| 13 |
+
from openai import OpenAI
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from sklearn.metrics import classification_report
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from utils.image_utils_yolo import YOLOProcessor
|
| 18 |
+
from utils.image_utils_bbox import updated_add_bbox
|
| 19 |
+
from utils.image_utils import add_bbox, encode_pil_image_to_base64
|
| 20 |
+
from utils.prompt_utils import create_query, parse_label
|
| 21 |
+
from embedding_service import embedding_service
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def optimize_image_for_tokens(image, max_size=768):
    """
    Simple image optimization to reduce token consumption.

    Downscales so the longest side is at most ``max_size`` pixels,
    preserving aspect ratio; images already within bounds are returned
    unchanged.

    Args:
        image: PIL Image object
        max_size: Maximum dimension (768 = good balance of quality vs tokens)

    Returns:
        Optimized PIL Image
    """
    original_size = image.size

    # Only resize if image is larger than max_size
    if max(image.size) > max_size:
        # Calculate new size maintaining aspect ratio
        ratio = max_size / max(image.size)
        new_size = (
            int(image.size[0] * ratio),
            int(image.size[1] * ratio)
        )

        # Resize with high-quality resampling
        image = image.resize(new_size, Image.Resampling.LANCZOS)

        # BUG FIX: the old message computed (orig/max)^2 * 100, which reported
        # nonsense like "400% token reduction" for a 2x downscale. Token count
        # scales with pixel area, so the saved fraction is 1 - ratio^2.
        reduction_pct = int((1 - ratio ** 2) * 100)
        print(f"Image optimized: {original_size} → {new_size} (estimated {reduction_pct}% token reduction)", flush=True)

    return image
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SingleImageInference:
|
| 55 |
+
    def __init__(
        self,
        server_url: str,
        segmentation_device_id: Optional[int] = None,
        log_dir: str = "inference_logs",
        enable_bbox_detection: bool = True,
    ):
        """Set up the OpenAI-compatible client, YAML config caches and helpers.

        Args:
            server_url: base URL of the OpenAI-compatible (vLLM) server.
            segmentation_device_id: stored for later use; not consumed here.
            log_dir: directory for inference logs; created if missing.
            enable_bbox_detection: stored flag controlling bbox overlays.
        """
        self.log_dir = log_dir
        self.enable_bbox_detection = enable_bbox_detection
        self.segmentation_device_id = segmentation_device_id

        # initialise client — api_key is a placeholder; vLLM ignores it.
        self.client = OpenAI(base_url=server_url, api_key="EMPTY")

        # Ensure log directory exists
        os.makedirs(log_dir, exist_ok=True)

        # Cache the available models to validate model existence
        self.available_models = self._get_available_models()

        # Load question mappings
        self.question_mappings = self._load_question_mappings()

        # Load document type groups
        self.doctype_groups = self._load_doctype_groups()

        # Load document - bbox class mapping
        self.doctype_detection_mapping = self._load_detection_class_mapping()

        # NOTE(review): YOLO device is hardcoded to cuda:0 here while
        # segmentation_device_id is configurable — confirm this is intentional.
        self.yolo_processor = YOLOProcessor(device="cuda:0")

        # Initialize CLIP model for embeddings; failure is non-fatal so the
        # service can still run without embedding support.
        try:
            embedding_service.load_model()
            print("✅ CLIP model loaded successfully for embeddings", flush=True)
        except Exception as e:
            print(f"⚠️ Warning: Failed to load CLIP model: {e}", flush=True)
|
| 92 |
+
|
| 93 |
+
def _get_available_models(self) -> List[str]:
|
| 94 |
+
for attempt in range(3):
|
| 95 |
+
try:
|
| 96 |
+
response = self.client.models.list()
|
| 97 |
+
return [model.id for model in response.data]
|
| 98 |
+
except Exception as e:
|
| 99 |
+
if attempt < 2: # Don't wait after the last attempt
|
| 100 |
+
wait_time = (attempt + 1) * 2 # 2s, then 4s
|
| 101 |
+
print(f"Attempt {attempt + 1} failed. Retrying in {wait_time}s...")
|
| 102 |
+
time.sleep(wait_time)
|
| 103 |
+
else:
|
| 104 |
+
print(f"Warning: Could not fetch available models: {str(e)}")
|
| 105 |
+
return []
|
| 106 |
+
|
| 107 |
+
def _load_question_mappings(self) -> Dict[str, Dict[str, str]]:
|
| 108 |
+
"""Load question mappings from YAML file."""
|
| 109 |
+
try:
|
| 110 |
+
mapping_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/question_mappings.yaml"
|
| 111 |
+
with open(mapping_file, 'r') as f:
|
| 112 |
+
return yaml.safe_load(f)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"Error loading question mappings: {str(e)}")
|
| 115 |
+
return {}
|
| 116 |
+
|
| 117 |
+
def _load_doctype_groups(self) -> Dict[str, List[str]]:
|
| 118 |
+
"""Load document type groups from YAML file."""
|
| 119 |
+
try:
|
| 120 |
+
groups_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/doctype_groups.yaml"
|
| 121 |
+
with open(groups_file, 'r') as f:
|
| 122 |
+
return yaml.safe_load(f)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"Error loading document type groups: {str(e)}")
|
| 125 |
+
return {}
|
| 126 |
+
|
| 127 |
+
def _get_parent_doctype(self, doc_type: str) -> str:
|
| 128 |
+
"""Get the parent document type for a given document type."""
|
| 129 |
+
# Ensure doc_type is uppercase for consistency
|
| 130 |
+
doc_type = doc_type.upper()
|
| 131 |
+
|
| 132 |
+
# Search through the groups to find the parent
|
| 133 |
+
for parent_type, child_types in self.doctype_groups.items():
|
| 134 |
+
if doc_type in child_types:
|
| 135 |
+
return parent_type
|
| 136 |
+
|
| 137 |
+
# If no parent found, return the original type
|
| 138 |
+
return doc_type
|
| 139 |
+
|
| 140 |
+
def _load_detection_class_mapping(self) -> Dict[str, List[List[str]]]:
|
| 141 |
+
"""Load and cache detection class mapping from YAML file."""
|
| 142 |
+
try:
|
| 143 |
+
mapping_file = "/app/meta-jv-reasoning/Teleco-Pilot-Use-Case/framework/utils/doctype_bbox_detection_mapping.yaml"
|
| 144 |
+
with open(mapping_file, "r") as f:
|
| 145 |
+
return yaml.safe_load(f)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error loading detection class mapping: {str(e)}")
|
| 148 |
+
return {}
|
| 149 |
+
|
| 150 |
+
def _get_detection_classes_for_doctype(self, parent_doctype: str) -> Optional[List[str]]:
|
| 151 |
+
"""Return list of detection classes for a given parent doctype, or None if not found."""
|
| 152 |
+
parent_doctype = parent_doctype.upper()
|
| 153 |
+
class_groups = self.doctype_detection_mapping.get(parent_doctype)
|
| 154 |
+
|
| 155 |
+
if not class_groups:
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
return [cls for group in class_groups for cls in group]
|
| 159 |
+
|
| 160 |
+
def _map_question(self, question: str, doc_type: str = "default") -> str:
|
| 161 |
+
"""Map complex questions to simpler versions based on document type."""
|
| 162 |
+
# Ensure doc_type is uppercase for consistency
|
| 163 |
+
doc_type = doc_type.upper()
|
| 164 |
+
|
| 165 |
+
# Map the document type to its parent type
|
| 166 |
+
parent_doc_type = self._get_parent_doctype(doc_type)
|
| 167 |
+
|
| 168 |
+
# Get mappings for the specific document type, fallback to default if not found
|
| 169 |
+
mappings = self.question_mappings.get(parent_doc_type, self.question_mappings.get("default", {}))
|
| 170 |
+
print(f"Mapping question for doc_type: {doc_type} (parent: {parent_doc_type})")
|
| 171 |
+
|
| 172 |
+
# Clean up the question by removing any numbering prefix
|
| 173 |
+
clean_question = re.sub(r'^\d+\.\s*', '', question)
|
| 174 |
+
|
| 175 |
+
# Try to find an exact match in the mappings
|
| 176 |
+
for complex_q, simple_q in mappings.items():
|
| 177 |
+
# Remove numbering from complex question for comparison
|
| 178 |
+
clean_complex_q = re.sub(r'^\d+\.\s*', '', complex_q)
|
| 179 |
+
if clean_question.startswith(clean_complex_q):
|
| 180 |
+
return simple_q
|
| 181 |
+
|
| 182 |
+
# If no match found, return the original question
|
| 183 |
+
return question
|
| 184 |
+
|
| 185 |
+
def _extract_reasoning(self, raw_response: Dict, doc_type: str = "default") -> str:
|
| 186 |
+
"""Extract and format reasoning from raw response into a single string."""
|
| 187 |
+
# Map the document type to its parent type
|
| 188 |
+
parent_doc_type = self._get_parent_doctype(doc_type)
|
| 189 |
+
|
| 190 |
+
reasoning = raw_response.get("reasoning", [])
|
| 191 |
+
if isinstance(reasoning, list):
|
| 192 |
+
formatted_reasoning = []
|
| 193 |
+
for i, item in enumerate(reasoning, 1):
|
| 194 |
+
if ": " in item:
|
| 195 |
+
question, answer = item.split(": ", 1)
|
| 196 |
+
mapped_question = self._map_question(question, parent_doc_type)
|
| 197 |
+
formatted_reasoning.append(f"{i}. {mapped_question}: {answer}")
|
| 198 |
+
else:
|
| 199 |
+
formatted_reasoning.append(f"{i}. {item}")
|
| 200 |
+
return "\n".join(formatted_reasoning)
|
| 201 |
+
return str(reasoning)
|
| 202 |
+
|
| 203 |
+
def run_inference(self, query, temperature: float) -> Dict[str, str]:
    """Run inference on a single image using the vLLM server.

    Args:
        query: Mapping with "image" (PIL Image or base64-encoded string),
            "task_instruction", "format_instruction", and optionally
            "doc_type" (defaults to "default").
        temperature: Sampling temperature forwarded to the chat completion.

    Returns:
        Dict with "response" (JSON string), "label" (parsed classification
        label or "UNKNOWN"), and "success" (whether a label was parsed).

    Raises:
        ValueError: If the image payload cannot be decoded, or no model is
            available on the server.
    """
    doc_type = query.get("doc_type", "default").upper()
    parent_doc_type = self._get_parent_doctype(doc_type)
    print(f"<-- parent_doc_type : {parent_doc_type} -->")

    # Accept either a ready PIL image or a base64-encoded payload.
    try:
        if isinstance(query["image"], Image.Image):
            image = query["image"].convert("RGB")
        else:
            image_data = base64.b64decode(query["image"])
            image = Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        print(f"Error decoding base64 image: {str(e)}")
        # Chain the original decode error so callers can see the root cause.
        raise ValueError("Invalid image format in query") from e

    # Extract embedding from original image before any bbox processing;
    # embedding failure is non-fatal (best effort).
    try:
        embedding = embedding_service.extract_embedding(image)
        print("✅ Embedding extracted successfully", flush=True)
    except Exception as e:
        print(f"⚠️ Warning: Failed to extract embedding: {e}", flush=True)
        embedding = None

    # Downscale to cap token usage before sending to the model.
    image = optimize_image_for_tokens(image, max_size=768)

    # Optionally overlay YOLO bounding boxes for this doctype's classes;
    # any failure here falls back to the plain image.
    if self.enable_bbox_detection:
        try:
            detection_classes = self._get_detection_classes_for_doctype(parent_doc_type)
            print(f"<< DETECTION CLASS : {detection_classes} >>", flush=True)

            if detection_classes:
                image_with_boxes = self.yolo_processor.process_bbox(
                    image, desired_classes=detection_classes
                )
                if image_with_boxes is not None:
                    image = image_with_boxes
            else:
                print(f"No detection class mapping found for parent doctype: {parent_doc_type}", flush=True)
        except Exception as e:
            print(f"Error applying bounding boxes: {str(e)}", flush=True)

    image_b64 = encode_pil_image_to_base64(image)
    instruction = f"{query['task_instruction']}\n\n{query['format_instruction']}"

    # Lazily resolve the model id from the server; fail loudly if none.
    if not self.available_models:
        self.available_models = self._get_available_models()
        if not self.available_models:
            raise ValueError(
                "No models available on the server. Please ensure the vLLM server is running and accessible."
            )
    model_id = self.available_models[0]

    response = self.client.chat.completions.create(
        model=model_id,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                    },
                ],
            }
        ],
        max_tokens=512,
        temperature=temperature,
        top_p=0.95,
    )

    raw_response = response.choices[0].message.content

    try:
        classification_label = parse_label(raw_response)
        success = True
    except AttributeError:
        # parse_label found no recognizable label in the response.
        classification_label = "UNKNOWN"
        success = False

    try:
        # Remove a surrounding Markdown code fence (``` or ```json) before
        # parsing. FIX: the previous raw_response.strip("`json\n") treated
        # the argument as a *character set*, so it could also eat leading
        # or trailing 'j'/'s'/'o'/'n' characters of the JSON body itself.
        json_str = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw_response.strip())
        raw_response_dict = json.loads(json_str)
        formatted_reasoning = self._extract_reasoning(raw_response_dict, doc_type)
        raw_response_dict["reasoning"] = formatted_reasoning

        # Add embedding to response if available
        if embedding is not None:
            raw_response_dict["embedding"] = embedding

        raw_response = json.dumps(raw_response_dict)
    except Exception as e:
        # Response was not valid JSON: wrap the raw text instead.
        print(f"Error formatting reasoning: {str(e)}")
        response_dict = {
            "reasoning": str(raw_response),
            "evaluation_result": classification_label
        }
        if embedding is not None:
            response_dict["embedding"] = embedding
        raw_response = json.dumps(response_dict)

    return {
        "response": raw_response,
        "label": classification_label,
        "success": success,
    }