import asyncio
import base64
import concurrent.futures
import io
import json
import os
import re
import sys
import tempfile
from contextlib import asynccontextmanager
from typing import Dict, Optional, List, Union

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from PIL import Image
from prometheus_fastapi_instrumentator import Instrumentator

# Add this file's directory to sys.path so the sibling modules
# (single_inferencing_2, utils/) import regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from single_inferencing_2 import SingleImageInference
from utils.prompt_utils import create_query, parse_label, create_query_updated
from utils.image_utils import encode_pil_image_to_base64

# --- GLOBAL VARS (constants; the inferencer itself lives on app.state) ---
LOG_DIR = os.getenv("LOG_DIR", "inference_logs")
SEGMENTATION_DEVICE_ID = int(os.getenv("SEGMENTATION_DEVICE_ID", "7"))
# Passed to SingleImageInference. Defaults to enabled when the variable is unset;
# set ENABLE_BBOX_DETECTION=false to turn it off.
ENABLE_BBOX_DETECTION = os.getenv("ENABLE_BBOX_DETECTION", "True").lower() == "true"
# Overridden by --vllm-url when this file is run as a script. Reading the environment
# also lets the app start when launched through the uvicorn CLI (`uvicorn <module>:app`).
VLLM_SERVER_URL: Optional[str] = os.getenv("VLLM_SERVER_URL")
MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "10"))  # Default chunk size for /infer/batch/
MAX_CONCURRENT_WORKERS = int(os.getenv("MAX_CONCURRENT_WORKERS", "4"))  # Thread-pool size

# Pillow image modes the JPEG encoder can write directly; other modes (e.g. RGBA, P)
# are converted to RGB before the temporary JPEG copy is written.
_JPEG_WRITABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


# --- Lifespan Context Manager ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown for the FastAPI application.

    On startup, builds the SingleImageInference client and the thread pool that runs
    the blocking inference calls, storing both on ``app.state``. If construction fails,
    both are set to None so the endpoints report the failure (500/503) instead of the
    process dying. Exits the process if no VLLM server URL is configured.

    On shutdown (including abnormal shutdown), drains the thread pool and calls the
    inferencer's ``close()`` if it has one.
    """
    if VLLM_SERVER_URL is None:
        print("ERROR: VLLM_SERVER_URL was not set before lifespan start. Exiting.", flush=True)
        sys.exit(1)

    print(f"Lifespan: Initializing inferencer for this worker with VLLM URL: {VLLM_SERVER_URL}", flush=True)
    try:
        app.state.inferencer = SingleImageInference(
            server_url=VLLM_SERVER_URL,
            log_dir=LOG_DIR,
            segmentation_device_id=SEGMENTATION_DEVICE_ID,
            enable_bbox_detection=ENABLE_BBOX_DETECTION,
        )
        # Blocking inference work is submitted here so the event loop stays responsive.
        app.state.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=MAX_CONCURRENT_WORKERS
        )
        print("Lifespan: Inferencer and thread pool successfully initialized.", flush=True)
    except Exception as e:
        print(f"Lifespan ERROR: Failed to initialize Inferencer: {e}", flush=True)
        app.state.inferencer = None
        app.state.thread_pool = None

    try:
        yield
    finally:
        # Shutdown cleanup runs even if the application exits with an error.
        print("Lifespan: Application shutdown. Performing cleanup.", flush=True)
        thread_pool = getattr(app.state, "thread_pool", None)
        if thread_pool is not None:
            thread_pool.shutdown(wait=True)
        close = getattr(getattr(app.state, "inferencer", None), "close", None)
        if callable(close):
            close()


# Initialize FastAPI app with lifespan
app = FastAPI(
    title="Llama Inferencing API with Batch Processing",
    description="API for running inference on images using Llama model - supports both single and batch processing",
    lifespan=lifespan
)

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

Instrumentator().instrument(app).expose(app)


# --- BaseModel Definitions ---
# Each item in `data` is a flat dict; see process_single_item for the expected keys.
class InferenceRequest(BaseModel):
    data: List[Dict[str, Union[str, float]]]


class BatchInferenceRequest(BaseModel):
    data: List[Dict[str, Union[str, float]]]
    batch_size: Optional[int] = None  # Optional chunk-size override (0/None -> MAX_BATCH_SIZE)


class InferenceResponse(BaseModel):
    body: Dict
    meta: Dict
    error: str


class BatchInferenceResponse(BaseModel):
    body: Dict
    meta: Dict
    error: str
    batch_info: Dict  # Additional batch processing info


def _temp_name_part(value) -> str:
    """Reduce a client-supplied id to a short, filesystem-safe filename fragment."""
    # Replacing every character outside [A-Za-z0-9._-] removes path separators, so the
    # fragment can never escape the temp directory.
    return re.sub(r"[^A-Za-z0-9._-]", "_", str(value))[:64]


def _write_temp_jpeg(pil_image, temp_dir: str, image_id, workorder_id) -> str:
    """
    Write a JPEG copy of ``pil_image`` to a new, uniquely named file in ``temp_dir``.

    The file name starts with the sanitized ids for traceability; mkstemp's random
    component prevents collisions between concurrent or duplicate items.
    Returns the path; the caller is responsible for deleting it.
    """
    # JPEG cannot store alpha/palette modes (e.g. RGBA from PNG uploads); convert those.
    to_save = pil_image if pil_image.mode in _JPEG_WRITABLE_MODES else pil_image.convert("RGB")
    fd, path = tempfile.mkstemp(
        prefix=f"{_temp_name_part(image_id)}_{_temp_name_part(workorder_id)}_",
        suffix=".jpg",
        dir=temp_dir,
    )
    try:
        with os.fdopen(fd, "wb") as fh:
            to_save.save(fh, format="JPEG")
    except Exception:
        # Do not leave a partially written file behind.
        os.remove(path)
        raise
    return path


def _parse_model_json(response_text) -> Dict:
    """
    Parse the model's JSON answer, tolerating a surrounding ```json markdown fence.

    Raises if the text is not valid JSON or does not decode to a JSON object.
    """
    # str.strip with a character set peels any leading/trailing run of backticks, the
    # letters j/s/o/n and whitespace. A JSON object starts with '{' and ends with '}',
    # so the object itself is never trimmed.
    parsed = json.loads(response_text.strip("`json \t\r\n"))
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def process_single_item(inferencer, item: Dict, temp_dir: str = "/tmp") -> Dict:
    """
    Run inference for one request item and format the result (shared by both endpoints).

    Args:
        inferencer: Object exposing ``run_inference(query, temperature)``; in this app a
            SingleImageInference instance.
        item: Request dict. Required keys: workorder_id, image_id, doc_type,
            business_type, workorder_type, image (base64-encoded image bytes).
            Optional keys: task_name (default "default"), format_name (default
            "reasoning_specrec"), temperature (default 0.1).
        temp_dir: Directory in which the temporary JPEG copy of the image is written.

    Returns:
        ``{"success": True, "result": <formatted dict>, "error": None}`` on success, or
        ``{"success": False, "result": None, "error": <message>}`` when any step fails.
        Exceptions from processing are caught and reported through this dict.
    """
    temp_image_path = None
    try:
        # Extract fields from the item (a missing key fails the item with a KeyError).
        workorder_id = item["workorder_id"]
        image_id = item["image_id"]
        doc_type = item["doc_type"]
        business_type = item["business_type"]
        workorder_type = item["workorder_type"]
        image_base64 = item["image"]

        pil_image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
        temp_image_path = _write_temp_jpeg(pil_image, temp_dir, image_id, workorder_id)

        # Create query for the image
        query = create_query_updated(
            temp_image_path,
            doc_type.lower(),
            [item.get("task_name", "default")],
            [item.get("format_name", "reasoning_specrec")],
        )[0]
        query["image"] = pil_image
        query["doc_type"] = doc_type.upper()

        print(f"Processing WORKORDERID: {workorder_id}, DOCTYPE: {query['doc_type']}", flush=True)

        # The request model's Union[str, float] field type can deliver numbers as
        # strings (e.g. under pydantic v1 coercion), so normalize to float here.
        temperature = float(item.get("temperature", 0.1))
        inference_result = inferencer.run_inference(query, temperature)

        # Parse the response; an unparseable answer maps to UNCERTAIN below.
        try:
            raw_response = _parse_model_json(inference_result["response"])
        except Exception as e:
            print(f"Failed to parse model response: {e}. Raw response: {inference_result.get('response')}", flush=True)
            raw_response = {
                "reasoning": "Failed to parse model response",
                "evaluation_result": "UNKNOWN"
            }

        # Normalize model_decision / review queue
        evaluation_result = raw_response.get("evaluation_result", "UNCERTAIN")
        if evaluation_result == "VALID":
            model_decision, review_queue = "VALID_INSTALL", "GREEN"
        elif evaluation_result == "INVALID":
            model_decision, review_queue = "INVALID_INSTALL", "RED"
        else:
            model_decision, review_queue = "UNCERTAIN", "YELLOW"

        formatted_result = {
            "workorder_id": workorder_id,
            "image_id": image_id,
            "doc_type": doc_type,
            "business_type": business_type,
            "workorder_type": workorder_type,
            "confidence_threshold": 0,
            "model_output": {
                "model_decision_reason": raw_response.get("reasoning", ""),
                "model_decision": model_decision,
                "recommendation": raw_response.get("recommendations", ""),
                # NOTE(review): hard-coded placeholder; the model's own value
                # (raw_response.get("serial_id", "")) is not returned. Confirm before
                # anything downstream relies on this field.
                "serial_id": "12345",
                "power_meter_reading": raw_response.get("power_meter_reading", ""),
                "review_queue": review_queue,
                "confidence_score": 0,
            }
        }

        # Add embedding to the response only when the model produced one.
        embedding = raw_response.get("embedding")
        if embedding is not None:
            formatted_result["embedding"] = embedding

        return {"success": True, "result": formatted_result, "error": None}

    except Exception as e:
        print(f"Error processing item {item.get('workorder_id', 'unknown')}: {e}", flush=True)
        return {"success": False, "result": None, "error": str(e)}

    finally:
        # Best-effort removal of the temporary file on both success and failure.
        if temp_image_path is not None:
            try:
                os.remove(temp_image_path)
            except OSError as cleanup_error:
                print(f"Failed to remove temporary file {temp_image_path}: {cleanup_error}", flush=True)


async def process_batch_chunk(inferencer, chunk: List[Dict], executor) -> List[Dict]:
    """
    Process a chunk of items concurrently on ``executor``.

    Returns one process_single_item result dict per item, in the same order as ``chunk``.
    """
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(executor, process_single_item, inferencer, item)
        for item in chunk
    ]
    return await asyncio.gather(*futures)


@app.post("/infer/", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """
    Run inference on the first item of ``request.data`` and return the result.

    Only ``request.data[0]`` is processed; further items are ignored. The blocking
    inference runs in the worker thread pool so the event loop keeps serving other
    requests. Processing failures are reported in ``error`` with an empty ``body.data``.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    if not request.data:
        return {
            "body": {"data": []},
            "meta": {},
            "error": "Request 'data' must contain at least one item."
        }

    try:
        loop = asyncio.get_running_loop()
        # A None executor falls back to asyncio's default thread pool.
        result = await loop.run_in_executor(
            getattr(app.state, "thread_pool", None),
            process_single_item,
            app.state.inferencer,
            request.data[0],
        )
    except Exception as e:
        print(f"API - Error during inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e)
        }

    if result["success"]:
        return {
            "body": {"data": [result["result"]]},
            "meta": {},
            "error": ""
        }
    return {
        "body": {"data": []},
        "meta": {},
        "error": result["error"]
    }


@app.post("/infer/batch/", response_model=BatchInferenceResponse)
async def run_batch_inference(request: BatchInferenceRequest):
    """
    Run inference on multiple images, chunk by chunk, with concurrency inside each chunk.

    Chunks are processed sequentially to bound load; items within a chunk run on the
    shared thread pool. Only successful results appear in ``body.data``; failures are
    counted in ``batch_info`` and listed in ``batch_info["failed_item_errors"]``.

    Raises:
        HTTPException(500): inferencer or thread pool not initialized.
        HTTPException(400): more than 5 * MAX_BATCH_SIZE items, or a negative batch_size.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")
    if app.state.thread_pool is None:
        raise HTTPException(status_code=500, detail="Thread pool not initialized.")

    data = request.data
    batch_size = request.batch_size or MAX_BATCH_SIZE  # 0/None -> server default
    max_items = MAX_BATCH_SIZE * 5  # Allow up to 5x max batch size

    # Validate before the try block: its generic `except Exception` would otherwise
    # catch these HTTPExceptions and turn them into 200 responses.
    if len(data) > max_items:
        raise HTTPException(
            status_code=400,
            detail=f"Batch too large. Maximum allowed: {max_items}, received: {len(data)}"
        )
    if batch_size < 1:
        # A negative step would make range() empty and silently skip every item.
        raise HTTPException(
            status_code=400,
            detail=f"batch_size must be a positive integer, received: {batch_size}"
        )

    try:
        print(f"Processing batch of {len(data)} items with batch_size={batch_size}", flush=True)

        # Split data into chunks
        chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

        all_results = []
        failed_items = []

        # Process chunks sequentially to avoid overwhelming the system
        for chunk_number, chunk in enumerate(chunks, start=1):
            print(f"Processing chunk {chunk_number}/{len(chunks)} with {len(chunk)} items", flush=True)
            chunk_results = await process_batch_chunk(
                app.state.inferencer, chunk, app.state.thread_pool
            )
            # gather() preserves order, so results line up with their request items.
            for item, result in zip(chunk, chunk_results):
                if result["success"]:
                    all_results.append(result["result"])
                else:
                    print(f"Failed to process item: {result['error']}", flush=True)
                    failed_items.append({
                        "workorder_id": item.get("workorder_id"),
                        "image_id": item.get("image_id"),
                        "error": result["error"],
                    })

        failed_count = len(failed_items)
        batch_info = {
            "total_items": len(data),
            "successful_items": len(all_results),
            "failed_items": failed_count,
            "batch_size_used": batch_size,
            "total_chunks": len(chunks),
            "failed_item_errors": failed_items,
        }

        return {
            "body": {"data": all_results},
            "meta": {"processing_time": "completed"},
            "error": f"{failed_count} items failed" if failed_count > 0 else "",
            "batch_info": batch_info
        }

    except Exception as e:
        print(f"API - Error during batch inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e),
            "batch_info": {"total_items": len(request.data), "successful_items": 0, "failed_items": len(request.data)}
        }


@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Returns 503 until both the inferencer and the thread pool are initialized.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=503, detail="Inferencer not initialized or failed to load")
    if app.state.thread_pool is None:
        raise HTTPException(status_code=503, detail="Thread pool not initialized")

    return {
        "status": "healthy",
        "max_batch_size": MAX_BATCH_SIZE,
        "max_concurrent_workers": MAX_CONCURRENT_WORKERS
    }


@app.get("/")
async def root():
    """
    Root endpoint for basic health check.
    """
    return {
        "status": "API is running",
        "service": "Llama Inferencing API with Batch Processing",
        "endpoints": {
            "single_inference": "/infer/",
            "batch_inference": "/infer/batch/",
            "health": "/health"
        }
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8877, help="API port")
    # Defaults come from the environment-derived globals above, so settings provided
    # via env vars are not silently overridden when running this file directly.
    parser.add_argument("--vllm-url", type=str, default=VLLM_SERVER_URL or "http://localhost:8000/v1",
                        help="VLLM server URL")
    parser.add_argument("--max-batch-size", type=int, default=MAX_BATCH_SIZE, help="Maximum batch size")
    parser.add_argument("--max-workers", type=int, default=MAX_CONCURRENT_WORKERS,
                        help="Maximum concurrent workers")
    args = parser.parse_args()

    # Store configuration globally (read by lifespan and the endpoints).
    VLLM_SERVER_URL = args.vllm_url
    MAX_BATCH_SIZE = args.max_batch_size
    MAX_CONCURRENT_WORKERS = args.max_workers

    print(f"Starting API server on port {args.port}", flush=True)
    print(f"VLLM URL: {args.vllm_url}", flush=True)
    print(f"Max batch size: {MAX_BATCH_SIZE}", flush=True)
    print(f"Max concurrent workers: {MAX_CONCURRENT_WORKERS}", flush=True)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False)