Spaces:
Sleeping
Sleeping
File size: 11,553 Bytes
d2a2955 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
"""
FastAPI application for QuickDraw sketch recognition.
Exposes API endpoints for VR/AR applications to classify drawings.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import logging
import os
import base64
from datetime import datetime
from pathlib import Path
import json
from model import SketchClassifier
from utils import preprocess_image_from_bytes, preprocess_image_from_base64
# Configure comprehensive logging
# Directory layout: api_logs/ holds text and per-request JSON logs;
# api_logs/received_images/ persists every image received so failing
# requests can be replayed offline.
LOG_DIR = "api_logs"
IMAGES_LOG_DIR = os.path.join(LOG_DIR, "received_images")
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(IMAGES_LOG_DIR, exist_ok=True)
# Setup logging to both file and console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(LOG_DIR, 'api.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Create separate logger for request details
# NOTE(review): "requests" logger still propagates to the root handlers, so
# anything logged here also lands in api.log and on the console — confirm
# intended. It is also never used elsewhere in this module as visible here.
request_logger = logging.getLogger("requests")
request_handler = logging.FileHandler(os.path.join(LOG_DIR, 'requests_detailed.log'))
request_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
request_logger.addHandler(request_handler)
request_logger.setLevel(logging.INFO)
# Initialize FastAPI app
app = FastAPI(
    title="QuickDraw Sketch Recognition API",
    description="API for recognizing hand-drawn sketches (house, cat, dog, car) for VR/AR applications",
    version="1.0.0"
)
# CORS middleware - adjust origins based on your VR application needs
# NOTE(review): wildcard origins together with allow_credentials=True is
# rejected by browsers per the CORS spec; tighten origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your VR app's origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model (singleton)
# Populated once by the startup event handler; every endpoint treats None
# as "model unavailable" and answers HTTP 503.
classifier = None
class PredictionRequest(BaseModel):
    """Request model for base64 encoded image"""
    # Base64-encoded image payload sent by the client; decoding is delegated
    # to utils.preprocess_image_from_base64.
    image_base64: str
    # Number of top predictions to return (default 3).
    top_k: Optional[int] = 3
class PredictionResponse(BaseModel):
    """Response model for predictions"""
    # Ranked prediction dicts as produced by classifier.predict()
    # (keys observed in this module: 'class', 'confidence', 'confidence_percent').
    predictions: List[dict]
    # True when inference completed without error.
    success: bool
    # Optional human-readable status; includes the request ID on success.
    message: Optional[str] = None
@app.on_event("startup")
async def startup_event():
"""Load the model on startup"""
global classifier
try:
logger.info("Loading QuickDraw model...")
classifier = SketchClassifier()
logger.info("Model loaded successfully!")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
@app.get("/")
async def root():
"""Root endpoint"""
return {
"message": "QuickDraw Sketch Recognition API",
"version": "1.0.0",
"endpoints": {
"/health": "Health check",
"/predict": "Predict from uploaded image file (POST)",
"/predict/base64": "Predict from base64 encoded image (POST)",
"/classes": "Get list of supported classes (GET)"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
model_loaded = classifier is not None
return {
"status": "healthy" if model_loaded else "unhealthy",
"model_loaded": model_loaded
}
@app.get("/classes")
async def get_classes():
"""Get list of supported drawing classes"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {
"classes": classifier.class_names,
"num_classes": len(classifier.class_names)
}
@app.post("/predict", response_model=PredictionResponse)
async def predict_from_file(
file: UploadFile = File(...),
top_k: int = 3,
http_request: Request = None
):
"""
Predict drawing class from uploaded image file.
Args:
file: Image file (PNG, JPG, etc.)
top_k: Number of top predictions to return (default: 3)
http_request: FastAPI request object for logging
Returns:
PredictionResponse with top predictions and confidence scores
"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Generate unique request ID
request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
logger.info(f"="*80)
logger.info(f"[FILE-REQUEST {request_id}] New file upload prediction")
logger.info(f"[FILE-REQUEST {request_id}] Filename: {file.filename}")
logger.info(f"[FILE-REQUEST {request_id}] Content-Type: {file.content_type}")
logger.info(f"[FILE-REQUEST {request_id}] Top K: {top_k}")
try:
# Read image bytes
image_bytes = await file.read()
logger.info(f"[FILE-REQUEST {request_id}] File size: {len(image_bytes)} bytes")
# Save uploaded file
uploaded_file = os.path.join(IMAGES_LOG_DIR, f"uploaded_{request_id}_{file.filename}")
with open(uploaded_file, 'wb') as f:
f.write(image_bytes)
logger.info(f"[FILE-REQUEST {request_id}] File saved to: {uploaded_file}")
# Preprocess image
logger.info(f"[FILE-REQUEST {request_id}] Preprocessing image...")
processed_image = preprocess_image_from_bytes(image_bytes)
logger.info(f"[FILE-REQUEST {request_id}] Preprocessed shape: {processed_image.shape}")
# Make prediction
logger.info(f"[FILE-REQUEST {request_id}] Running inference...")
predictions = classifier.predict(processed_image, top_k=top_k)
# Log predictions
logger.info(f"[FILE-REQUEST {request_id}] PREDICTIONS:")
for i, pred in enumerate(predictions, 1):
logger.info(f"[FILE-REQUEST {request_id}] {i}. {pred['class']}: {pred['confidence_percent']}")
logger.info(f"[FILE-REQUEST {request_id}] ✓ Success")
logger.info(f"="*80)
return PredictionResponse(
predictions=predictions,
success=True,
message=f"Prediction successful (Request ID: {request_id})"
)
except Exception as e:
logger.error(f"[FILE-REQUEST {request_id}] ✗ FAILED: {e}")
logger.info(f"="*80)
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
@app.post("/predict/base64", response_model=PredictionResponse)
async def predict_from_base64(request: PredictionRequest, http_request: Request):
"""
Predict drawing class from base64 encoded image.
Ideal for VR/AR applications sending image data directly.
Args:
request: PredictionRequest containing base64 image and optional top_k
http_request: FastAPI request object for logging
Returns:
PredictionResponse with top predictions and confidence scores
"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Generate unique request ID
request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# Log incoming request details
logger.info(f"="*80)
logger.info(f"[REQUEST {request_id}] New prediction request from VR")
logger.info(f"[REQUEST {request_id}] Client: {http_request.client.host}:{http_request.client.port}")
logger.info(f"[REQUEST {request_id}] User-Agent: {http_request.headers.get('user-agent', 'Unknown')}")
logger.info(f"[REQUEST {request_id}] Top K: {request.top_k}")
# Log base64 image details
base64_length = len(request.image_base64)
logger.info(f"[REQUEST {request_id}] Base64 image length: {base64_length} characters")
logger.info(f"[REQUEST {request_id}] Base64 prefix (first 100 chars): {request.image_base64[:100]}...")
# Save base64 string to file for debugging
base64_log_file = os.path.join(LOG_DIR, f"request_{request_id}_base64.txt")
with open(base64_log_file, 'w') as f:
f.write(request.image_base64)
logger.info(f"[REQUEST {request_id}] Base64 saved to: {base64_log_file}")
try:
# Decode and save the actual image
try:
image_data = base64.b64decode(request.image_base64)
image_file = os.path.join(IMAGES_LOG_DIR, f"request_{request_id}.png")
with open(image_file, 'wb') as f:
f.write(image_data)
logger.info(f"[REQUEST {request_id}] Decoded image saved to: {image_file}")
logger.info(f"[REQUEST {request_id}] Decoded image size: {len(image_data)} bytes")
except Exception as decode_error:
logger.warning(f"[REQUEST {request_id}] Failed to decode/save image: {decode_error}")
# Preprocess image from base64
logger.info(f"[REQUEST {request_id}] Preprocessing image...")
processed_image = preprocess_image_from_base64(request.image_base64)
logger.info(f"[REQUEST {request_id}] Preprocessed image shape: {processed_image.shape}")
# Make prediction
logger.info(f"[REQUEST {request_id}] Running model inference...")
predictions = classifier.predict(processed_image, top_k=request.top_k)
# Log predictions
logger.info(f"[REQUEST {request_id}] PREDICTIONS:")
for i, pred in enumerate(predictions, 1):
logger.info(f"[REQUEST {request_id}] {i}. {pred['class']}: {pred['confidence_percent']} (confidence: {pred['confidence']:.4f})")
# Save detailed request log as JSON
request_log = {
"request_id": request_id,
"timestamp": datetime.now().isoformat(),
"client_ip": http_request.client.host,
"client_port": http_request.client.port,
"user_agent": http_request.headers.get('user-agent', 'Unknown'),
"base64_length": base64_length,
"image_file": image_file if 'image_file' in locals() else None,
"top_k": request.top_k,
"predictions": predictions,
"success": True
}
json_log_file = os.path.join(LOG_DIR, f"request_{request_id}.json")
with open(json_log_file, 'w') as f:
json.dump(request_log, f, indent=2)
logger.info(f"[REQUEST {request_id}] Full request log saved to: {json_log_file}")
logger.info(f"[REQUEST {request_id}] ✓ Prediction completed successfully")
logger.info(f"="*80)
return PredictionResponse(
predictions=predictions,
success=True,
message=f"Prediction successful (Request ID: {request_id})"
)
except Exception as e:
logger.error(f"[REQUEST {request_id}] ✗ Prediction FAILED")
logger.error(f"[REQUEST {request_id}] Error: {str(e)}")
logger.error(f"[REQUEST {request_id}] Error type: {type(e).__name__}")
logger.info(f"="*80)
# Save error log
error_log = {
"request_id": request_id,
"timestamp": datetime.now().isoformat(),
"error": str(e),
"error_type": type(e).__name__,
"base64_length": base64_length,
"success": False
}
error_log_file = os.path.join(LOG_DIR, f"request_{request_id}_ERROR.json")
with open(error_log_file, 'w') as f:
json.dump(error_log, f, indent=2)
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
if __name__ == "__main__":
# Run the API server
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level="info"
)
|