import asyncio
import base64
import concurrent.futures
import io
import json
import os
import sys
import tempfile
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from prometheus_fastapi_instrumentator import Instrumentator
from pydantic import BaseModel
|
|
| |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
| from single_inferencing_2 import SingleImageInference |
| from utils.prompt_utils import create_query, parse_label, create_query_updated |
| from utils.image_utils import encode_pil_image_to_base64 |
|
|
| |
# Runtime configuration, read from environment variables at import time.
LOG_DIR = os.getenv("LOG_DIR", "inference_logs")
SEGMENTATION_DEVICE_ID = int(os.getenv("SEGMENTATION_DEVICE_ID", "7"))
# NOTE(review): this flag is not passed to the inferencer in this file (the
# lifespan hard-codes enable_bbox_detection=True); confirm which is intended.
ENABLE_BBOX_DETECTION = os.getenv("ENABLE_BBOX_DETECTION", "False").lower() == "true"
# Taken from the environment so the app can also be started by an external
# server process (e.g. ``uvicorn module:app``); the ``__main__`` entry point
# overrides it from ``--vllm-url``. An empty value is treated as unset.
VLLM_SERVER_URL: Optional[str] = os.getenv("VLLM_SERVER_URL") or None
MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "10"))
MAX_CONCURRENT_WORKERS = int(os.getenv("MAX_CONCURRENT_WORKERS", "4"))
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown for the FastAPI application (per worker).

    Startup creates a ``SingleImageInference`` client and a thread pool and
    stores them on ``app.state``. If either cannot be created, both attributes
    are set to ``None`` so the endpoints and ``/health`` report the failure
    instead of the process crashing.

    Shutdown runs in a ``finally`` block, so it also happens when the server
    exits with an error. It drains the thread pool and closes the inferencer
    when it exposes ``close()``.

    Exits the process with status 1 if ``VLLM_SERVER_URL`` was never set.
    """
    if VLLM_SERVER_URL is None:
        print("ERROR: VLLM_SERVER_URL was not set before lifespan start. Exiting.", flush=True)
        sys.exit(1)

    print(f"Lifespan: Initializing inferencer for this worker with VLLM URL: {VLLM_SERVER_URL}", flush=True)
    app.state.inferencer = None
    app.state.thread_pool = None
    try:
        app.state.inferencer = SingleImageInference(
            server_url=VLLM_SERVER_URL,
            log_dir=LOG_DIR,
            segmentation_device_id=SEGMENTATION_DEVICE_ID,
            # NOTE(review): hard-coded, so ENABLE_BBOX_DETECTION from the
            # environment is ignored. Left as-is because switching to that
            # setting (default "False") would silently disable bbox detection.
            enable_bbox_detection=True,
        )

        app.state.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=MAX_CONCURRENT_WORKERS
        )

        print("Lifespan: Inferencer and thread pool successfully initialized.", flush=True)
    except Exception as e:
        print(f"Lifespan ERROR: Failed to initialize Inferencer: {e}", flush=True)
        partial_inferencer = app.state.inferencer
        app.state.inferencer = None
        app.state.thread_pool = None
        # The inferencer may already exist when only the pool failed (e.g.
        # ThreadPoolExecutor rejects max_workers <= 0); release it rather than
        # dropping it unclosed.
        if hasattr(partial_inferencer, "close"):
            try:
                partial_inferencer.close()
            except Exception as close_err:
                print(f"Lifespan ERROR: Failed to close partially initialized Inferencer: {close_err}", flush=True)

    try:
        yield
    finally:
        print("Lifespan: Application shutdown. Performing cleanup.", flush=True)
        thread_pool = getattr(app.state, "thread_pool", None)
        if thread_pool:
            thread_pool.shutdown(wait=True)
        inferencer = getattr(app.state, "inferencer", None)
        if hasattr(inferencer, "close"):
            inferencer.close()
|
|
| |
app = FastAPI(
    lifespan=lifespan,
    title="Llama Inferencing API with Batch Processing",
    description=(
        "API for running inference on images using Llama model - "
        "supports both single and batch processing"
    ),
)

# Wide-open CORS: every origin, method and header is accepted.
# NOTE(review): credentials are allowed together with a wildcard origin.
# Browsers refuse a literal "*" on credentialed responses, and Starlette may
# reflect the caller's Origin instead (verify for the installed version).
# Pin allow_origins if cookies or auth headers are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Instrument every route with Prometheus metrics and mount the metrics endpoint.
_instrumentator = Instrumentator()
_instrumentator.instrument(app)
_instrumentator.expose(app)
|
|
| |
# One work item as sent by clients: a flat mapping whose values are strings
# (ids, doc type, base64 image, ...) or numbers (e.g. "temperature").
_InferenceItem = Dict[str, Union[str, float]]


class InferenceRequest(BaseModel):
    """Body of ``POST /infer/``; only the first entry of ``data`` is processed."""

    data: List[_InferenceItem]


class BatchInferenceRequest(BaseModel):
    """Body of ``POST /infer/batch/``; ``batch_size``, when given, sets the chunk size."""

    data: List[_InferenceItem]
    batch_size: Optional[int] = None


class InferenceResponse(BaseModel):
    """Envelope for single results: items under ``body["data"]``, ``error`` is "" on success."""

    body: Dict
    meta: Dict
    error: str


class BatchInferenceResponse(BaseModel):
    """Envelope for batch results; ``batch_info`` carries item and chunk counts."""

    body: Dict
    meta: Dict
    error: str
    batch_info: Dict
|
|
# Pillow image modes the JPEG encoder can write directly. Anything else
# (e.g. RGBA / LA / P from PNG uploads) must be converted before saving as .jpg.
_JPEG_SAVE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def _decode_base64_image(image_base64: str) -> bytes:
    """Decode a base64 image payload, accepting an optional ``data:...;base64,`` prefix."""
    # ':' is not a base64 character, so a bare payload never starts with
    # "data:". Without removing the header, the non-validating b64decode would
    # turn its letters into junk bytes in front of the image.
    if image_base64.startswith("data:") and "," in image_base64:
        image_base64 = image_base64.split(",", 1)[1]
    return base64.b64decode(image_base64)


def _make_temp_image_path(temp_dir: str, image_id, workorder_id) -> str:
    """Create a new, uniquely named empty ``.jpg`` file in *temp_dir* and return its path.

    The ids come from the request body. Only ASCII letters, digits, ``-`` and
    ``_`` are kept, so no path separators or ``..`` can escape *temp_dir*.
    ``mkstemp`` adds a random component so concurrent items with identical ids
    never share (or delete) each other's file.
    """
    stem = f"{image_id}_{workorder_id}"
    safe_stem = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_" for ch in stem
    )[:100]
    fd, path = tempfile.mkstemp(prefix=f"{safe_stem}_", suffix=".jpg", dir=temp_dir)
    os.close(fd)
    return path


def _save_jpeg_copy(pil_image, path: str) -> None:
    """Write *pil_image* to *path* as JPEG, converting to RGB if JPEG cannot store its mode."""
    if pil_image.mode not in _JPEG_SAVE_MODES:
        pil_image = pil_image.convert("RGB")
    pil_image.save(path, format="JPEG")


def _parse_model_response(response_text: str) -> dict:
    """Parse the model's reply into a JSON object (dict).

    Tolerates a surrounding Markdown code fence (```json ... ```). As a
    fallback, it also tolerates extra text around a single ``{...}`` object.

    Raises:
        ValueError: (including json.JSONDecodeError) if no JSON object can be
            extracted or the JSON is not an object.
        AttributeError: if *response_text* is not a string.
    """
    text = response_text.strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if not 0 <= start < end:
            raise
        parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def process_single_item(inferencer, item: Dict, temp_dir: str = "/tmp") -> Dict:
    """
    Run the inference pipeline for one request item (used by both endpoints).

    Required keys: ``workorder_id``, ``image_id``, ``doc_type``,
    ``business_type``, ``workorder_type`` and a base64 ``image``. Optional keys:
    ``task_name``, ``format_name`` and ``temperature``. The decoded image is
    written to a temporary JPEG in *temp_dir*, whose path is passed to
    ``create_query_updated``. The file is removed on every path.

    Args:
        inferencer: Object exposing ``run_inference(query, temperature)`` that
            returns a mapping with a ``"response"`` text.
        item: One entry of the request's ``data`` list.
        temp_dir: Directory for the temporary image file.

    Returns:
        ``{"success": True, "result": <formatted result>, "error": None}``, or
        ``{"success": False, "result": None, "error": <message>}`` when
        processing fails. Processing errors are reported here, not raised.
    """
    temp_image_path = None
    try:
        workorder_id = item["workorder_id"]
        image_id = item["image_id"]
        doc_type = item["doc_type"]
        business_type = item["business_type"]
        workorder_type = item["workorder_type"]
        image_base64 = item["image"]

        pil_image = Image.open(io.BytesIO(_decode_base64_image(image_base64)))

        temp_image_path = _make_temp_image_path(temp_dir, image_id, workorder_id)
        _save_jpeg_copy(pil_image, temp_image_path)

        query = create_query_updated(
            temp_image_path,
            doc_type.lower(),
            [item.get("task_name", "default")],
            [item.get("format_name", "reasoning_specrec")],
        )[0]

        # The query handed to the inferencer carries the in-memory image.
        query["image"] = pil_image
        query["doc_type"] = doc_type.upper()

        print(f"Processing WORKORDERID: {workorder_id}, DOCTYPE: {query['doc_type']}", flush=True)

        # The request schema is Union[str, float]; depending on the pydantic
        # version a numeric temperature may arrive as a string, so normalise it.
        temperature = float(item.get("temperature", 0.1))
        inference_result = inferencer.run_inference(query, temperature)

        try:
            raw_response = _parse_model_response(inference_result["response"])
        except Exception as e:
            # Deliberate best-effort: an unparseable reply is routed to the
            # UNCERTAIN/YELLOW bucket below instead of failing the item.
            print(f"Failed to parse model response: {e}. Raw response: {inference_result.get('response')}", flush=True)
            raw_response = {
                "reasoning": "Failed to parse model response",
                "evaluation_result": "UNKNOWN"
            }

        evaluation_result = raw_response.get("evaluation_result", "UNCERTAIN")

        # Map the model verdict to a decision label and review queue colour.
        if evaluation_result == "VALID":
            model_decision = "VALID_INSTALL"
            review_queue = "GREEN"
        elif evaluation_result == "INVALID":
            model_decision = "INVALID_INSTALL"
            review_queue = "RED"
        else:
            model_decision = "UNCERTAIN"
            review_queue = "YELLOW"

        embedding = raw_response.get("embedding")

        formatted_result = {
            "workorder_id": workorder_id,
            "image_id": image_id,
            "doc_type": doc_type,
            "business_type": business_type,
            "workorder_type": workorder_type,
            "confidence_threshold": 0,
            "model_output": {
                "model_decision_reason": raw_response.get("reasoning", ""),
                "model_decision": model_decision,
                "recommendation": raw_response.get("recommendations", ""),
                # NOTE(review): constant placeholder, not derived from the model output.
                "serial_id": "12345",
                "power_meter_reading": raw_response.get("power_meter_reading", ""),
                "review_queue": review_queue,
                "confidence_score": 0,
            }
        }

        # Only attach the embedding when the model actually returned one.
        if embedding is not None:
            formatted_result["embedding"] = embedding

        return {"success": True, "result": formatted_result, "error": None}

    except Exception as e:
        print(f"Error processing item {item.get('workorder_id', 'unknown')}: {e}", flush=True)
        return {"success": False, "result": None, "error": str(e)}

    finally:
        # Best-effort cleanup on every path. A failed delete must not turn a
        # finished item into an error or escape as an exception.
        if temp_image_path is not None:
            try:
                os.remove(temp_image_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_err:
                print(f"Failed to remove temp file {temp_image_path}: {cleanup_err}", flush=True)
|
|
async def process_batch_chunk(inferencer, chunk: List[Dict], executor) -> List[Dict]:
    """
    Process a chunk of items concurrently on *executor*.

    Each item becomes its own ``process_single_item`` task on the executor, and
    the returned list preserves the order of *chunk*. ``process_single_item``
    reports per-item failures in its return value, so ``gather`` normally does
    not raise.
    """
    # get_running_loop() is the documented way to reach the loop from inside a
    # coroutine; get_event_loop() is discouraged there.
    loop = asyncio.get_running_loop()
    futures = [
        loop.run_in_executor(executor, process_single_item, inferencer, item)
        for item in chunk
    ]
    return await asyncio.gather(*futures)
|
|
@app.post("/infer/", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """
    Run inference on the first image in ``request.data`` and return the results.

    ``process_single_item`` blocks for the whole model call. It therefore runs
    in a worker thread, using the same pool as the batch endpoint, so the event
    loop keeps serving other requests meanwhile. Processing failures are
    reported in the ``error`` field of a normal response. Only a missing
    inferencer produces an HTTP 500.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    if not request.data:
        return {
            "body": {"data": []},
            "meta": {},
            "error": "Request 'data' must contain at least one item."
        }

    try:
        loop = asyncio.get_running_loop()
        # None selects the loop's default executor if the app pool is absent.
        executor = getattr(app.state, "thread_pool", None)
        result = await loop.run_in_executor(
            executor, process_single_item, app.state.inferencer, request.data[0]
        )

        if result["success"]:
            return {
                "body": {"data": [result["result"]]},
                "meta": {},
                "error": ""
            }
        return {
            "body": {"data": []},
            "meta": {},
            "error": result["error"]
        }
    except Exception as e:
        print(f"API - Error during inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e)
        }
|
|
@app.post("/infer/batch/", response_model=BatchInferenceResponse)
async def run_batch_inference(request: BatchInferenceRequest):
    """
    Run inference on multiple images, processing them chunk by chunk.

    Items are split into chunks of ``batch_size``: ``request.batch_size`` when
    it is a positive number, otherwise ``MAX_BATCH_SIZE``. Each chunk is fanned
    out to the shared thread pool and awaited before the next chunk starts.
    Successful results are returned in ``body["data"]``. Failed items are left
    out of it and counted in ``batch_info``.

    Raises:
        HTTPException: 500 if the inferencer or thread pool is unavailable;
            400 if more than ``MAX_BATCH_SIZE * 5`` items are submitted.
    """
    if app.state.inferencer is None:
        raise HTTPException(status_code=500, detail="Inferencer not initialized or failed to load.")

    if app.state.thread_pool is None:
        raise HTTPException(status_code=500, detail="Thread pool not initialized.")

    data = request.data

    # Validated before the try block below: its catch-all handler would
    # otherwise swallow this HTTPException and answer 200 instead of 400.
    if len(data) > MAX_BATCH_SIZE * 5:
        raise HTTPException(
            status_code=400,
            detail=f"Batch too large. Maximum allowed: {MAX_BATCH_SIZE * 5}, received: {len(data)}"
        )

    # None/0 already meant "use the default". A negative size would give
    # range() a negative step, produce no chunks and silently skip every item.
    requested_size = request.batch_size
    batch_size = requested_size if requested_size is not None and requested_size > 0 else MAX_BATCH_SIZE

    try:
        print(f"Processing batch of {len(data)} items with batch_size={batch_size}", flush=True)

        chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

        all_results = []
        successful_count = 0
        failed_count = 0

        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i + 1}/{len(chunks)} with {len(chunk)} items", flush=True)

            chunk_results = await process_batch_chunk(
                app.state.inferencer,
                chunk,
                app.state.thread_pool
            )

            for result in chunk_results:
                if result["success"]:
                    all_results.append(result["result"])
                    successful_count += 1
                else:
                    failed_count += 1
                    print(f"Failed to process item: {result['error']}", flush=True)

        batch_info = {
            "total_items": len(data),
            "successful_items": successful_count,
            "failed_items": failed_count,
            "batch_size_used": batch_size,
            "total_chunks": len(chunks)
        }

        return {
            "body": {"data": all_results},
            "meta": {"processing_time": "completed"},
            "error": f"{failed_count} items failed" if failed_count > 0 else "",
            "batch_info": batch_info
        }

    except Exception as e:
        print(f"API - Error during batch inference: {e}", flush=True)
        return {
            "body": {"data": []},
            "meta": {},
            "error": str(e),
            "batch_info": {"total_items": len(request.data), "successful_items": 0, "failed_items": len(request.data)}
        }
|
|
@app.get("/health")
async def health_check():
    """
    Readiness probe.

    Answers 503 until both the inferencer and the worker thread pool exist on
    ``app.state``. Once ready, reports the configured batch and concurrency limits.
    """
    state = app.state
    required = (
        ("inferencer", "Inferencer not initialized or failed to load"),
        ("thread_pool", "Thread pool not initialized"),
    )
    for attr_name, failure_detail in required:
        if getattr(state, attr_name) is None:
            raise HTTPException(status_code=503, detail=failure_detail)

    return dict(
        status="healthy",
        max_batch_size=MAX_BATCH_SIZE,
        max_concurrent_workers=MAX_CONCURRENT_WORKERS,
    )
|
|
@app.get("/")
async def root():
    """Liveness endpoint that also advertises the service name and its routes."""
    routes = dict(
        single_inference="/infer/",
        batch_inference="/infer/batch/",
        health="/health",
    )
    return {
        "status": "API is running",
        "service": "Llama Inferencing API with Batch Processing",
        "endpoints": routes,
    }
|
|
if __name__ == "__main__":
    import argparse

    # CLI defaults fall back to the environment-derived module settings so that
    # running the script does not silently discard MAX_BATCH_SIZE /
    # MAX_CONCURRENT_WORKERS (or a VLLM_SERVER_URL value) set in the environment.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8877, help="API port")
    parser.add_argument(
        "--vllm-url",
        type=str,
        default=VLLM_SERVER_URL or "http://localhost:8000/v1",
        help="VLLM server URL",
    )
    parser.add_argument("--max-batch-size", type=int, default=MAX_BATCH_SIZE, help="Maximum batch size")
    parser.add_argument("--max-workers", type=int, default=MAX_CONCURRENT_WORKERS, help="Maximum concurrent workers")

    args = parser.parse_args()

    # Reject values that would otherwise surface later as an empty chunk loop
    # or a thread pool that fails to start (leaving the worker unhealthy).
    if args.max_batch_size < 1:
        parser.error("--max-batch-size must be >= 1")
    if args.max_workers < 1:
        parser.error("--max-workers must be >= 1")

    # lifespan() and the endpoints read these module globals at startup and
    # request time, so overriding them before uvicorn.run() takes effect.
    VLLM_SERVER_URL = args.vllm_url
    MAX_BATCH_SIZE = args.max_batch_size
    MAX_CONCURRENT_WORKERS = args.max_workers

    print(f"Starting API server on port {args.port}", flush=True)
    print(f"VLLM URL: {args.vllm_url}", flush=True)
    print(f"Max batch size: {MAX_BATCH_SIZE}", flush=True)
    print(f"Max concurrent workers: {MAX_CONCURRENT_WORKERS}", flush=True)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False)
|
|