File size: 11,553 Bytes
d2a2955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
FastAPI application for QuickDraw sketch recognition.
Exposes API endpoints for VR/AR applications to classify drawings.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import logging
import os
import base64
from datetime import datetime
from pathlib import Path
import json

from model import SketchClassifier
from utils import preprocess_image_from_bytes, preprocess_image_from_base64

# Logging setup: everything lands under api_logs/, with decoded request
# images kept in a dedicated subdirectory.
LOG_DIR = "api_logs"
IMAGES_LOG_DIR = os.path.join(LOG_DIR, "received_images")
for _directory in (LOG_DIR, IMAGES_LOG_DIR):
    os.makedirs(_directory, exist_ok=True)

# Root logging is mirrored to both api.log and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(LOG_DIR, 'api.log')),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# A dedicated "requests" logger writes per-request details to its own file.
request_logger = logging.getLogger("requests")
request_handler = logging.FileHandler(os.path.join(LOG_DIR, 'requests_detailed.log'))
request_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
request_logger.addHandler(request_handler)
request_logger.setLevel(logging.INFO)

# Initialize FastAPI app
app = FastAPI(
    title="QuickDraw Sketch Recognition API",
    description="API for recognizing hand-drawn sketches (house, cat, dog, car) for VR/AR applications",
    version="1.0.0"
)

# CORS middleware - adjust origins based on your VR application needs.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory under the CORS spec (wildcard origins cannot be credentialed);
# the middleware will not send both — confirm the intended origin policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your VR app's origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model (singleton). Remains None until the startup hook below
# loads SketchClassifier; endpoints return 503 while it is None.
classifier = None


class PredictionRequest(BaseModel):
    """Request body for the /predict/base64 endpoint.

    Carries the sketch as a base64-encoded image string plus how many
    ranked predictions the caller wants back.
    """
    # Base64-encoded image bytes; decoded via base64.b64decode and
    # utils.preprocess_image_from_base64 in the endpoint.
    image_base64: str
    # Number of top-ranked classes to return (defaults to 3).
    top_k: Optional[int] = 3


class PredictionResponse(BaseModel):
    """Response envelope returned by both prediction endpoints."""
    # Ranked prediction dicts; per the endpoint logging each carries at
    # least 'class', 'confidence', and 'confidence_percent' keys.
    predictions: List[dict]
    # True when inference completed without error.
    success: bool
    # Human-readable status; on success it embeds the request ID.
    message: Optional[str] = None


@app.on_event("startup")
async def startup_event():
    """Load the classifier once at server boot, populating the module-level singleton."""
    global classifier
    try:
        logger.info("Loading QuickDraw model...")
        classifier = SketchClassifier()
        logger.info("Model loaded successfully!")
    except Exception as exc:
        # Re-raise so the server refuses to start without a working model.
        logger.error(f"Failed to load model: {exc}")
        raise


@app.get("/")
async def root():
    """Index endpoint: names the service and lists the available routes."""
    endpoints = {
        "/health": "Health check",
        "/predict": "Predict from uploaded image file (POST)",
        "/predict/base64": "Predict from base64 encoded image (POST)",
        "/classes": "Get list of supported classes (GET)",
    }
    return {
        "message": "QuickDraw Sketch Recognition API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }


@app.get("/health")
async def health_check():
    """Report whether the classifier singleton has been loaded."""
    is_ready = classifier is not None
    status = "healthy" if is_ready else "unhealthy"
    return {"status": status, "model_loaded": is_ready}


@app.get("/classes")
async def get_classes():
    """Return the drawing classes the model recognizes, with their count."""
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    names = classifier.class_names
    return {"classes": names, "num_classes": len(names)}


@app.post("/predict", response_model=PredictionResponse)
async def predict_from_file(
    file: UploadFile = File(...),
    top_k: int = 3,
    http_request: Request = None
):
    """
    Predict drawing class from uploaded image file.

    Args:
        file: Image file (PNG, JPG, etc.)
        top_k: Number of top predictions to return (default: 3)
        http_request: FastAPI request object for logging

    Returns:
        PredictionResponse with top predictions and confidence scores

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on any
            preprocessing or inference failure.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Timestamp-based ID ties together all log lines for this request.
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    logger.info("=" * 80)
    logger.info(f"[FILE-REQUEST {request_id}] New file upload prediction")
    logger.info(f"[FILE-REQUEST {request_id}] Filename: {file.filename}")
    logger.info(f"[FILE-REQUEST {request_id}] Content-Type: {file.content_type}")
    logger.info(f"[FILE-REQUEST {request_id}] Top K: {top_k}")

    try:
        # Read image bytes
        image_bytes = await file.read()
        logger.info(f"[FILE-REQUEST {request_id}] File size: {len(image_bytes)} bytes")

        # SECURITY: file.filename is client-controlled — strip any directory
        # components before joining, otherwise "../../x" escapes IMAGES_LOG_DIR.
        safe_name = os.path.basename(file.filename or "upload")
        uploaded_file = os.path.join(IMAGES_LOG_DIR, f"uploaded_{request_id}_{safe_name}")
        with open(uploaded_file, 'wb') as f:
            f.write(image_bytes)
        logger.info(f"[FILE-REQUEST {request_id}] File saved to: {uploaded_file}")

        # Preprocess image into the model's expected input array.
        logger.info(f"[FILE-REQUEST {request_id}] Preprocessing image...")
        processed_image = preprocess_image_from_bytes(image_bytes)
        logger.info(f"[FILE-REQUEST {request_id}] Preprocessed shape: {processed_image.shape}")

        # Make prediction
        logger.info(f"[FILE-REQUEST {request_id}] Running inference...")
        predictions = classifier.predict(processed_image, top_k=top_k)

        # Log each ranked prediction for traceability.
        logger.info(f"[FILE-REQUEST {request_id}] PREDICTIONS:")
        for i, pred in enumerate(predictions, 1):
            logger.info(f"[FILE-REQUEST {request_id}]   {i}. {pred['class']}: {pred['confidence_percent']}")

        logger.info(f"[FILE-REQUEST {request_id}] ✓ Success")
        logger.info("=" * 80)

        return PredictionResponse(
            predictions=predictions,
            success=True,
            message=f"Prediction successful (Request ID: {request_id})"
        )

    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"[FILE-REQUEST {request_id}] ✗ FAILED: {e}")
        logger.info("=" * 80)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


@app.post("/predict/base64", response_model=PredictionResponse)
async def predict_from_base64(request: PredictionRequest, http_request: Request):
    """
    Predict drawing class from base64 encoded image.
    Ideal for VR/AR applications sending image data directly.

    Args:
        request: PredictionRequest containing base64 image and optional top_k
        http_request: FastAPI request object for logging

    Returns:
        PredictionResponse with top predictions and confidence scores

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on any
            decoding, preprocessing, or inference failure.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Timestamp-based ID ties together all logs and artifacts for this request.
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    # Log incoming request details
    logger.info("=" * 80)
    logger.info(f"[REQUEST {request_id}] New prediction request from VR")
    logger.info(f"[REQUEST {request_id}] Client: {http_request.client.host}:{http_request.client.port}")
    logger.info(f"[REQUEST {request_id}] User-Agent: {http_request.headers.get('user-agent', 'Unknown')}")
    logger.info(f"[REQUEST {request_id}] Top K: {request.top_k}")

    # Log base64 image details
    base64_length = len(request.image_base64)
    logger.info(f"[REQUEST {request_id}] Base64 image length: {base64_length} characters")
    logger.info(f"[REQUEST {request_id}] Base64 prefix (first 100 chars): {request.image_base64[:100]}...")

    # Best-effort debug dump of the raw base64 payload. A disk error here
    # must not abort the prediction itself (previously it escaped unhandled).
    try:
        base64_log_file = os.path.join(LOG_DIR, f"request_{request_id}_base64.txt")
        with open(base64_log_file, 'w') as f:
            f.write(request.image_base64)
        logger.info(f"[REQUEST {request_id}] Base64 saved to: {base64_log_file}")
    except OSError as dump_error:
        logger.warning(f"[REQUEST {request_id}] Failed to save base64 dump: {dump_error}")

    # Initialized up-front so the JSON audit record below can reference it
    # safely (replaces the fragile "'image_file' in locals()" check).
    image_file = None

    try:
        # Decode and save the actual image (best-effort, for debugging).
        try:
            image_data = base64.b64decode(request.image_base64)
            image_file = os.path.join(IMAGES_LOG_DIR, f"request_{request_id}.png")
            with open(image_file, 'wb') as f:
                f.write(image_data)
            logger.info(f"[REQUEST {request_id}] Decoded image saved to: {image_file}")
            logger.info(f"[REQUEST {request_id}] Decoded image size: {len(image_data)} bytes")
        except Exception as decode_error:
            image_file = None  # don't record a path that was never fully written
            logger.warning(f"[REQUEST {request_id}] Failed to decode/save image: {decode_error}")

        # Preprocess image from base64 into the model's expected input array.
        logger.info(f"[REQUEST {request_id}] Preprocessing image...")
        processed_image = preprocess_image_from_base64(request.image_base64)
        logger.info(f"[REQUEST {request_id}] Preprocessed image shape: {processed_image.shape}")

        # Make prediction
        logger.info(f"[REQUEST {request_id}] Running model inference...")
        predictions = classifier.predict(processed_image, top_k=request.top_k)

        # Log each ranked prediction for traceability.
        logger.info(f"[REQUEST {request_id}] PREDICTIONS:")
        for i, pred in enumerate(predictions, 1):
            logger.info(f"[REQUEST {request_id}]   {i}. {pred['class']}: {pred['confidence_percent']} (confidence: {pred['confidence']:.4f})")

        # Persist a structured per-request audit record as JSON.
        request_log = {
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "client_ip": http_request.client.host,
            "client_port": http_request.client.port,
            "user_agent": http_request.headers.get('user-agent', 'Unknown'),
            "base64_length": base64_length,
            "image_file": image_file,
            "top_k": request.top_k,
            "predictions": predictions,
            "success": True
        }

        json_log_file = os.path.join(LOG_DIR, f"request_{request_id}.json")
        with open(json_log_file, 'w') as f:
            json.dump(request_log, f, indent=2)
        logger.info(f"[REQUEST {request_id}] Full request log saved to: {json_log_file}")

        logger.info(f"[REQUEST {request_id}] ✓ Prediction completed successfully")
        logger.info("=" * 80)

        return PredictionResponse(
            predictions=predictions,
            success=True,
            message=f"Prediction successful (Request ID: {request_id})"
        )

    except Exception as e:
        # logger.exception records the full traceback for debugging.
        logger.exception(f"[REQUEST {request_id}] ✗ Prediction FAILED")
        logger.error(f"[REQUEST {request_id}] Error: {str(e)}")
        logger.error(f"[REQUEST {request_id}] Error type: {type(e).__name__}")
        logger.info(f"="*80)

        # Save error log (best-effort — never mask the original failure).
        error_log = {
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "error": str(e),
            "error_type": type(e).__name__,
            "base64_length": base64_length,
            "success": False
        }
        try:
            error_log_file = os.path.join(LOG_DIR, f"request_{request_id}_ERROR.json")
            with open(error_log_file, 'w') as f:
                json.dump(error_log, f, indent=2)
        except OSError as log_error:
            logger.warning(f"[REQUEST {request_id}] Failed to write error log: {log_error}")

        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


if __name__ == "__main__":
    # Development entry point: serve on all interfaces with auto-reload.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("main:app", **server_options)