issaennab committed on
Commit d2a2955 · 1 Parent(s): 586dd8f

Deploy QuickDraw API with trained model and comprehensive logging
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.keras filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,33 @@
+# Hugging Face Space Dockerfile for QuickDraw API
+FROM python:3.10-slim
+
+# Create a non-root user
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+
+WORKDIR /app
+
+# Install system dependencies as root, then drop back to the user
+USER root
+RUN apt-get update && apt-get install -y \
+    libgomp1 \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+USER user
+
+# Copy requirements and install
+COPY --chown=user ./requirements.txt requirements.txt
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+# Copy application files
+COPY --chown=user . /app
+
+# Create directories for logs
+RUN mkdir -p api_logs/received_images
+
+# Expose port 7860 (required by HF Spaces)
+EXPOSE 7860
+
+# Start the API on port 7860
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,84 @@
 ---
-title: Quickdraw Api
-emoji:
+title: QuickDraw Sketch Recognition API
+emoji: 🎨
 colorFrom: blue
-colorTo: gray
+colorTo: purple
 sdk: docker
 pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# QuickDraw Sketch Recognition API
+
+Real-time sketch recognition API for VR/AR applications. Recognizes 46 different hand-drawn objects using a CNN trained on Google's QuickDraw dataset.
+
+## 🎯 Try It Out
+
+Once the Space is running, you can:
+
+### Test via Swagger UI
+Visit the API docs at: `https://issa-ennab-quickdraw-api.hf.space/docs`
+
+### Test via cURL
+```bash
+# Health check
+curl https://issa-ennab-quickdraw-api.hf.space/health
+
+# Get supported classes
+curl https://issa-ennab-quickdraw-api.hf.space/classes
+
+# Make a prediction (replace with your base64 image)
+curl -X POST https://issa-ennab-quickdraw-api.hf.space/predict/base64 \
+  -H "Content-Type: application/json" \
+  -d '{"image_base64": "YOUR_BASE64_IMAGE", "top_k": 3}'
+```
+
+### Unity/VR Integration
+```csharp
+private string apiUrl = "https://issa-ennab-quickdraw-api.hf.space/predict/base64";
+```
+
+## 📋 Supported Classes (46 total)
+
+**Animals:** cat, dog, bird, fish, bear, butterfly, spider
+**Buildings:** house, castle, barn, bridge, lighthouse, church
+**Transportation:** car, airplane, bicycle, truck, train
+**Nature:** tree, flower, sun, moon, cloud, mountain
+**Objects:** apple, banana, book, chair, table, cup, umbrella
+**Body Parts:** face, eye, hand, foot
+**Shapes:** circle, triangle, square, star
+**Tools:** sword, axe, hammer, key, crown
+**Music:** guitar, piano
+
+## 🔧 API Endpoints
+
+- `GET /` - API information
+- `GET /health` - Health check
+- `GET /classes` - List all supported classes
+- `POST /predict` - Upload an image file for prediction
+- `POST /predict/base64` - Send a base64-encoded image (recommended for VR)
+
+## 🎮 Perfect For
+
+- VR/AR drawing applications
+- Educational games
+- Real-time sketch recognition
+- Interactive art tools
+
+## 📊 Model Performance
+
+- **Accuracy:** 84.89% on the validation set
+- **Inference Time:** ~50-80ms on CPU
+- **Model Size:** 2.9 MB
+- **Input:** 28x28 grayscale images
+
+## 📖 Full Documentation
+
+[GitHub Repository](https://github.com/Beakal-23/Augmented-Reality--Image-Detector-Final-Project-)
+
+## 🚀 Built With
+
+- FastAPI for the REST API
+- TensorFlow/Keras for the CNN model
+- Google QuickDraw dataset
+- Docker for deployment
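
For callers scripting against the Space, the JSON body expected by `POST /predict/base64` can be assembled as follows. This is a minimal sketch: `build_predict_payload` is an illustrative helper, not part of the deployed API.

```python
import base64
import json

def build_predict_payload(image_bytes: bytes, top_k: int = 3) -> str:
    """Serialize the JSON body expected by POST /predict/base64."""
    return json.dumps({
        "image_base64": base64.b64encode(image_bytes).decode("ascii"),
        "top_k": top_k,
    })

# Example with dummy bytes; in practice read your PNG/JPG file instead.
payload = build_predict_payload(b"\x89PNG dummy bytes", top_k=3)
```

Send it with any HTTP client, e.g. `requests.post(url, data=payload, headers={"Content-Type": "application/json"})`.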
app.py ADDED
@@ -0,0 +1,317 @@
+"""
+FastAPI application for QuickDraw sketch recognition.
+Exposes API endpoints for VR/AR applications to classify drawings.
+"""
+from fastapi import FastAPI, File, UploadFile, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from typing import List, Optional
+import uvicorn
+import logging
+import os
+import base64
+from datetime import datetime
+import json
+
+from model import SketchClassifier
+from utils import preprocess_image_from_bytes, preprocess_image_from_base64
+
+# Configure comprehensive logging
+LOG_DIR = "api_logs"
+IMAGES_LOG_DIR = os.path.join(LOG_DIR, "received_images")
+os.makedirs(LOG_DIR, exist_ok=True)
+os.makedirs(IMAGES_LOG_DIR, exist_ok=True)
+
+# Set up logging to both file and console
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler(os.path.join(LOG_DIR, 'api.log')),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger(__name__)
+
+# Create a separate logger for request details
+request_logger = logging.getLogger("requests")
+request_handler = logging.FileHandler(os.path.join(LOG_DIR, 'requests_detailed.log'))
+request_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
+request_logger.addHandler(request_handler)
+request_logger.setLevel(logging.INFO)
+
+# Initialize the FastAPI app
+app = FastAPI(
+    title="QuickDraw Sketch Recognition API",
+    description="API for recognizing hand-drawn sketches for VR/AR applications",
+    version="1.0.0"
+)
+
+# CORS middleware - adjust origins based on your VR application's needs
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # In production, specify your VR app's origin
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Initialize the model (singleton)
+classifier = None
+
+
+class PredictionRequest(BaseModel):
+    """Request model for a base64-encoded image"""
+    image_base64: str
+    top_k: Optional[int] = 3
+
+
+class PredictionResponse(BaseModel):
+    """Response model for predictions"""
+    predictions: List[dict]
+    success: bool
+    message: Optional[str] = None
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Load the model on startup"""
+    global classifier
+    try:
+        logger.info("Loading QuickDraw model...")
+        classifier = SketchClassifier()
+        logger.info("Model loaded successfully!")
+    except Exception as e:
+        logger.error(f"Failed to load model: {e}")
+        raise
+
+
+@app.get("/")
+async def root():
+    """Root endpoint"""
+    return {
+        "message": "QuickDraw Sketch Recognition API",
+        "version": "1.0.0",
+        "endpoints": {
+            "/health": "Health check",
+            "/predict": "Predict from uploaded image file (POST)",
+            "/predict/base64": "Predict from base64 encoded image (POST)",
+            "/classes": "Get list of supported classes (GET)"
+        }
+    }
+
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    model_loaded = classifier is not None
+    return {
+        "status": "healthy" if model_loaded else "unhealthy",
+        "model_loaded": model_loaded
+    }
+
+
+@app.get("/classes")
+async def get_classes():
+    """Get the list of supported drawing classes"""
+    if classifier is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    return {
+        "classes": classifier.class_names,
+        "num_classes": len(classifier.class_names)
+    }
+
+
+@app.post("/predict", response_model=PredictionResponse)
+async def predict_from_file(
+    file: UploadFile = File(...),
+    top_k: int = 3
+):
+    """
+    Predict the drawing class from an uploaded image file.
+
+    Args:
+        file: Image file (PNG, JPG, etc.)
+        top_k: Number of top predictions to return (default: 3)
+
+    Returns:
+        PredictionResponse with top predictions and confidence scores
+    """
+    if classifier is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    # Generate a unique request ID
+    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+
+    logger.info("=" * 80)
+    logger.info(f"[FILE-REQUEST {request_id}] New file upload prediction")
+    logger.info(f"[FILE-REQUEST {request_id}] Filename: {file.filename}")
+    logger.info(f"[FILE-REQUEST {request_id}] Content-Type: {file.content_type}")
+    logger.info(f"[FILE-REQUEST {request_id}] Top K: {top_k}")
+
+    try:
+        # Read the image bytes
+        image_bytes = await file.read()
+        logger.info(f"[FILE-REQUEST {request_id}] File size: {len(image_bytes)} bytes")
+
+        # Save the uploaded file
+        uploaded_file = os.path.join(IMAGES_LOG_DIR, f"uploaded_{request_id}_{file.filename}")
+        with open(uploaded_file, 'wb') as f:
+            f.write(image_bytes)
+        logger.info(f"[FILE-REQUEST {request_id}] File saved to: {uploaded_file}")
+
+        # Preprocess the image
+        logger.info(f"[FILE-REQUEST {request_id}] Preprocessing image...")
+        processed_image = preprocess_image_from_bytes(image_bytes)
+        logger.info(f"[FILE-REQUEST {request_id}] Preprocessed shape: {processed_image.shape}")
+
+        # Make a prediction
+        logger.info(f"[FILE-REQUEST {request_id}] Running inference...")
+        predictions = classifier.predict(processed_image, top_k=top_k)
+
+        # Log the predictions
+        logger.info(f"[FILE-REQUEST {request_id}] PREDICTIONS:")
+        for i, pred in enumerate(predictions, 1):
+            logger.info(f"[FILE-REQUEST {request_id}]   {i}. {pred['class']}: {pred['confidence_percent']}")
+
+        logger.info(f"[FILE-REQUEST {request_id}] ✓ Success")
+        logger.info("=" * 80)
+
+        return PredictionResponse(
+            predictions=predictions,
+            success=True,
+            message=f"Prediction successful (Request ID: {request_id})"
+        )
+
+    except Exception as e:
+        logger.error(f"[FILE-REQUEST {request_id}] ✗ FAILED: {e}")
+        logger.info("=" * 80)
+        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
+
+
+@app.post("/predict/base64", response_model=PredictionResponse)
+async def predict_from_base64(request: PredictionRequest, http_request: Request):
+    """
+    Predict the drawing class from a base64-encoded image.
+    Ideal for VR/AR applications sending image data directly.
+
+    Args:
+        request: PredictionRequest containing the base64 image and optional top_k
+        http_request: FastAPI request object, used for logging
+
+    Returns:
+        PredictionResponse with top predictions and confidence scores
+    """
+    if classifier is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    # Generate a unique request ID
+    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+
+    # Log incoming request details
+    logger.info("=" * 80)
+    logger.info(f"[REQUEST {request_id}] New prediction request from VR")
+    logger.info(f"[REQUEST {request_id}] Client: {http_request.client.host}:{http_request.client.port}")
+    logger.info(f"[REQUEST {request_id}] User-Agent: {http_request.headers.get('user-agent', 'Unknown')}")
+    logger.info(f"[REQUEST {request_id}] Top K: {request.top_k}")
+
+    # Log base64 image details
+    base64_length = len(request.image_base64)
+    logger.info(f"[REQUEST {request_id}] Base64 image length: {base64_length} characters")
+    logger.info(f"[REQUEST {request_id}] Base64 prefix (first 100 chars): {request.image_base64[:100]}...")
+
+    # Save the base64 string to a file for debugging
+    base64_log_file = os.path.join(LOG_DIR, f"request_{request_id}_base64.txt")
+    with open(base64_log_file, 'w') as f:
+        f.write(request.image_base64)
+    logger.info(f"[REQUEST {request_id}] Base64 saved to: {base64_log_file}")
+
+    try:
+        # Decode and save the actual image
+        try:
+            image_data = base64.b64decode(request.image_base64)
+            image_file = os.path.join(IMAGES_LOG_DIR, f"request_{request_id}.png")
+            with open(image_file, 'wb') as f:
+                f.write(image_data)
+            logger.info(f"[REQUEST {request_id}] Decoded image saved to: {image_file}")
+            logger.info(f"[REQUEST {request_id}] Decoded image size: {len(image_data)} bytes")
+        except Exception as decode_error:
+            logger.warning(f"[REQUEST {request_id}] Failed to decode/save image: {decode_error}")
+
+        # Preprocess the image from base64
+        logger.info(f"[REQUEST {request_id}] Preprocessing image...")
+        processed_image = preprocess_image_from_base64(request.image_base64)
+        logger.info(f"[REQUEST {request_id}] Preprocessed image shape: {processed_image.shape}")
+
+        # Make a prediction
+        logger.info(f"[REQUEST {request_id}] Running model inference...")
+        predictions = classifier.predict(processed_image, top_k=request.top_k)
+
+        # Log the predictions
+        logger.info(f"[REQUEST {request_id}] PREDICTIONS:")
+        for i, pred in enumerate(predictions, 1):
+            logger.info(f"[REQUEST {request_id}]   {i}. {pred['class']}: {pred['confidence_percent']} (confidence: {pred['confidence']:.4f})")
+
+        # Save a detailed request log as JSON
+        request_log = {
+            "request_id": request_id,
+            "timestamp": datetime.now().isoformat(),
+            "client_ip": http_request.client.host,
+            "client_port": http_request.client.port,
+            "user_agent": http_request.headers.get('user-agent', 'Unknown'),
+            "base64_length": base64_length,
+            "image_file": image_file if 'image_file' in locals() else None,
+            "top_k": request.top_k,
+            "predictions": predictions,
+            "success": True
+        }
+
+        json_log_file = os.path.join(LOG_DIR, f"request_{request_id}.json")
+        with open(json_log_file, 'w') as f:
+            json.dump(request_log, f, indent=2)
+        logger.info(f"[REQUEST {request_id}] Full request log saved to: {json_log_file}")
+
+        logger.info(f"[REQUEST {request_id}] ✓ Prediction completed successfully")
+        logger.info("=" * 80)
+
+        return PredictionResponse(
+            predictions=predictions,
+            success=True,
+            message=f"Prediction successful (Request ID: {request_id})"
+        )
+
+    except Exception as e:
+        logger.error(f"[REQUEST {request_id}] ✗ Prediction FAILED")
+        logger.error(f"[REQUEST {request_id}] Error: {str(e)}")
+        logger.error(f"[REQUEST {request_id}] Error type: {type(e).__name__}")
+        logger.info("=" * 80)
+
+        # Save an error log
+        error_log = {
+            "request_id": request_id,
+            "timestamp": datetime.now().isoformat(),
+            "error": str(e),
+            "error_type": type(e).__name__,
+            "base64_length": base64_length,
+            "success": False
+        }
+        error_log_file = os.path.join(LOG_DIR, f"request_{request_id}_ERROR.json")
+        with open(error_log_file, 'w') as f:
+            json.dump(error_log, f, indent=2)
+
+        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
+
+
+if __name__ == "__main__":
+    # Run the API server (this module is app.py, so the import string is "app:app")
+    uvicorn.run(
+        "app:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=True,
+        log_level="info"
+    )
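
The per-request JSON logging pattern above (a timestamp-derived request ID plus a `request_{id}.json` sidecar file) can be sketched in isolation with the standard library. `write_request_log` is a hypothetical helper for illustration, not a function in app.py:

```python
import json
import os
import tempfile
from datetime import datetime

def write_request_log(log_dir: str, predictions: list) -> str:
    """Write a JSON sidecar file like app.py's request_{id}.json and return its path."""
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    record = {
        "request_id": request_id,
        "timestamp": datetime.now().isoformat(),
        "predictions": predictions,
        "success": True,
    }
    path = os.path.join(log_dir, f"request_{request_id}.json")
    with open(path, "w") as f:
        json.dump(record, f, indent=2)
    return path

log_dir = tempfile.mkdtemp()
log_path = write_request_log(log_dir, [{"class": "cat", "confidence": 0.9}])
```

The microsecond component of the timestamp keeps concurrent requests from colliding on the same file name in practice.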
config.py ADDED
@@ -0,0 +1,65 @@
+"""
+Configuration settings for the QuickDraw API.
+Modify these settings based on your deployment needs.
+"""
+import os
+from typing import List
+
+
+class Settings:
+    """Application settings"""
+
+    # API settings
+    API_TITLE: str = "QuickDraw Sketch Recognition API"
+    API_VERSION: str = "1.0.0"
+    API_DESCRIPTION: str = "API for recognizing hand-drawn sketches for VR/AR applications"
+
+    # Server settings
+    HOST: str = "0.0.0.0"
+    PORT: int = 8000
+    RELOAD: bool = False  # Set to True for development
+
+    # CORS settings
+    CORS_ORIGINS: List[str] = ["*"]  # In production, specify allowed origins
+    CORS_ALLOW_CREDENTIALS: bool = True
+    CORS_ALLOW_METHODS: List[str] = ["*"]
+    CORS_ALLOW_HEADERS: List[str] = ["*"]
+
+    # Model settings
+    MODEL_PATH: str = os.path.join("saved_models", "quickdraw_house_cat_dog_car.keras")
+    CLASS_NAMES: List[str] = [
+        # Animals (7)
+        "cat", "dog", "bird", "fish", "bear", "butterfly", "spider",
+        # Buildings & structures (6)
+        "house", "castle", "barn", "bridge", "lighthouse", "church",
+        # Transportation (5)
+        "car", "airplane", "bicycle", "truck", "train",
+        # Nature (6)
+        "tree", "flower", "sun", "moon", "cloud", "mountain",
+        # Common objects (7)
+        "apple", "banana", "book", "chair", "table", "cup", "umbrella",
+        # People & body (4)
+        "face", "eye", "hand", "foot",
+        # Shapes (4)
+        "circle", "triangle", "square", "star",
+        # Tools & items (5)
+        "sword", "axe", "hammer", "key", "crown",
+        # Musical instruments (2)
+        "guitar", "piano"
+    ]
+
+    # Prediction settings
+    DEFAULT_TOP_K: int = 3
+    CONFIDENCE_THRESHOLD: float = 0.5  # Minimum confidence for a valid prediction
+
+    # Image processing settings
+    INPUT_IMAGE_SIZE: tuple = (28, 28)
+    GRAYSCALE: bool = True
+    NORMALIZE: bool = True  # Normalize pixel values to [0, 1]
+
+    # Logging
+    LOG_LEVEL: str = "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
+
+
+# Create a singleton instance
+settings = Settings()
model.py ADDED
@@ -0,0 +1,132 @@
+"""
+Model inference module for QuickDraw sketch classification.
+Handles model loading and prediction logic.
+"""
+import os
+import numpy as np
+import tensorflow as tf
+from typing import Any, Dict, List
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class SketchClassifier:
+    """QuickDraw sketch classifier"""
+
+    def __init__(self, model_path: str = None):
+        """
+        Initialize the classifier with a trained model.
+
+        Args:
+            model_path: Path to the trained model file. If None, uses the default path.
+        """
+        # Extended class list matching Model-Training.py.
+        # NOTE: the order and length of this list must match the model's
+        # output layer (the class order used during training).
+        self.class_names = [
+            # Animals
+            "cat", "dog", "bird", "fish", "bear", "butterfly", "bee", "spider",
+            # Buildings & structures
+            "house", "castle", "barn", "bridge", "lighthouse", "church",
+            # Transportation
+            "car", "airplane", "bicycle", "boat", "train", "truck", "bus",
+            # Nature
+            "tree", "flower", "sun", "moon", "cloud", "mountain", "river",
+            # Common objects
+            "apple", "banana", "book", "chair", "table", "cup", "umbrella",
+            # People & body
+            "face", "eye", "hand", "foot",
+            # Shapes & symbols
+            "circle", "triangle", "square", "star", "heart",
+            # Tools & items
+            "sword", "axe", "hammer", "key", "crown"
+        ]
+
+        # Default model path
+        if model_path is None:
+            model_path = os.path.join("saved_models", "quickdraw_house_cat_dog_car.keras")
+
+        # Check whether the model exists
+        if not os.path.exists(model_path):
+            # Try the .h5 format as a fallback
+            h5_path = model_path.replace(".keras", ".h5")
+            if os.path.exists(h5_path):
+                model_path = h5_path
+                logger.info(f"Using H5 model format: {model_path}")
+            else:
+                raise FileNotFoundError(
+                    f"Model file not found at {model_path}. "
+                    "Please train the model first using Model-Training.py"
+                )
+
+        logger.info(f"Loading model from: {model_path}")
+        self.model = tf.keras.models.load_model(model_path)
+        logger.info("Model loaded successfully!")
+
+        # Record the expected input shape
+        self.input_shape = self.model.input_shape[1:]  # e.g. (28, 28, 1)
+        logger.info(f"Model input shape: {self.input_shape}")
+
+    def predict(self, image: np.ndarray, top_k: int = 3) -> List[Dict[str, Any]]:
+        """
+        Make a prediction on a preprocessed image.
+
+        Args:
+            image: Preprocessed image array of shape (1, 28, 28, 1)
+            top_k: Number of top predictions to return
+
+        Returns:
+            List of dictionaries containing class names and confidence scores
+        """
+        # Validate the input shape
+        if image.shape != (1, 28, 28, 1):
+            raise ValueError(
+                f"Expected input shape (1, 28, 28, 1), got {image.shape}. "
+                "Please preprocess the image first."
+            )
+
+        # Make a prediction
+        predictions = self.model.predict(image, verbose=0)
+
+        # Get the top k predictions
+        top_indices = np.argsort(predictions[0])[::-1][:top_k]
+
+        results = []
+        for idx in top_indices:
+            results.append({
+                "class": self.class_names[idx],
+                "confidence": float(predictions[0][idx]),
+                "confidence_percent": f"{predictions[0][idx] * 100:.2f}%"
+            })
+
+        return results
+
+    def predict_batch(self, images: np.ndarray, top_k: int = 3) -> List[List[Dict[str, Any]]]:
+        """
+        Make predictions on a batch of preprocessed images.
+
+        Args:
+            images: Batch of preprocessed images of shape (N, 28, 28, 1)
+            top_k: Number of top predictions to return per image
+
+        Returns:
+            List of prediction results for each image
+        """
+        predictions = self.model.predict(images, verbose=0)
+
+        results = []
+        for pred in predictions:
+            # Get the top k predictions for this image
+            top_indices = np.argsort(pred)[::-1][:top_k]
+
+            image_results = []
+            for idx in top_indices:
+                image_results.append({
+                    "class": self.class_names[idx],
+                    "confidence": float(pred[idx]),
+                    "confidence_percent": f"{pred[idx] * 100:.2f}%"
+                })
+
+            results.append(image_results)
+
+        return results
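
The top-k selection inside `predict` boils down to an `argsort` over the softmax output. A self-contained illustration with a synthetic probability vector and placeholder labels (none of these values come from the real model):

```python
import numpy as np

probs = np.array([0.05, 0.70, 0.10, 0.15])  # fake softmax output
names = ["cat", "dog", "bird", "fish"]      # placeholder labels

# Indices of the highest probabilities, descending (same idiom as model.py)
top_indices = np.argsort(probs)[::-1][:3]
top = [(names[i], float(probs[i])) for i in top_indices]
# top == [("dog", 0.7), ("fish", 0.15), ("bird", 0.1)]
```

Reversing the ascending `argsort` and slicing the first k entries is equivalent to sorting by probability in descending order.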
requirements.txt ADDED
@@ -0,0 +1,30 @@
+# QuickDraw Sketch Recognition API
+# Compatible with Python 3.10+ on Windows, macOS (Intel & Apple Silicon), and Linux
+
+# Core dependencies
+fastapi>=0.115.2
+uvicorn[standard]>=0.24.0
+pydantic>=2.7.4
+python-multipart>=0.0.18
+
+# ML/AI libraries
+tensorflow>=2.15.0
+numpy>=1.25.0,<2.0  # TensorFlow 2.15 requires numpy < 2.0
+scikit-learn>=1.3.2
+matplotlib>=3.8.2
+
+# Image processing
+Pillow>=10.1.0
+
+# ONNX support (optional, for model export)
+tf2onnx>=1.15.1
+onnx>=1.15.0
+onnxruntime>=1.16.3
+
+# Development and testing
+pytest>=7.4.3
+httpx>=0.25.2
+requests>=2.32.2
+
+# Hugging Face integration
+huggingface-hub>=0.20.0
saved_models/quickdraw_house_cat_dog_car.h5 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24dd8c8b1b1e19b927d937f8fae3ba1507ce312ee35e4f3e015591a327e3edfe
+size 3000896
saved_models/quickdraw_house_cat_dog_car.keras ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa5ca71b085fb590fed2d5a550154f905b90516c98617e3e0c8f665ce2bd6590
+size 2999536
saved_models/quickdraw_house_cat_dog_car.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c339e3d8798df6c473f15cb052e98f5bff92cc711e2ee4058f695b27f185ac6
+size 989107
utils.py ADDED
@@ -0,0 +1,199 @@
+"""
+Utility functions for image preprocessing.
+Handles various input formats: bytes, base64, PIL images, etc.
+"""
+import io
+import base64
+import numpy as np
+from PIL import Image
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def preprocess_image_from_bytes(image_bytes: bytes) -> np.ndarray:
+    """
+    Preprocess an image from raw bytes.
+
+    Args:
+        image_bytes: Raw image bytes (PNG, JPG, etc.)
+
+    Returns:
+        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]
+    """
+    try:
+        # Load the image from bytes
+        image = Image.open(io.BytesIO(image_bytes))
+
+        # Convert to grayscale
+        image = image.convert('L')
+
+        # Resize to 28x28
+        image = image.resize((28, 28), Image.Resampling.LANCZOS)
+
+        # Convert to a numpy array
+        image_array = np.array(image, dtype=np.float32)
+
+        # Normalize to [0, 1]
+        image_array = image_array / 255.0
+
+        # Reshape to (1, 28, 28, 1) for model input
+        image_array = image_array.reshape(1, 28, 28, 1)
+
+        return image_array
+
+    except Exception as e:
+        logger.error(f"Error preprocessing image from bytes: {e}")
+        raise ValueError(f"Failed to process image: {str(e)}")
+
+
+def preprocess_image_from_base64(base64_string: str) -> np.ndarray:
+    """
+    Preprocess an image from a base64-encoded string.
+
+    Args:
+        base64_string: Base64-encoded image string (with or without a data URI prefix)
+
+    Returns:
+        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]
+    """
+    try:
+        # Remove the data URI prefix if present (e.g. "data:image/png;base64,")
+        if ',' in base64_string and base64_string.startswith('data:'):
+            base64_string = base64_string.split(',', 1)[1]
+
+        # Decode base64 to bytes
+        image_bytes = base64.b64decode(base64_string)
+
+        # Reuse the bytes preprocessing function
+        return preprocess_image_from_bytes(image_bytes)
+
+    except Exception as e:
+        logger.error(f"Error preprocessing image from base64: {e}")
+        raise ValueError(f"Failed to process base64 image: {str(e)}")
+
+
+def preprocess_image_from_array(image_array: np.ndarray) -> np.ndarray:
+    """
+    Preprocess an image from a numpy array.
+    Handles various input shapes and formats.
+
+    Args:
+        image_array: Numpy array representing an image
+
+    Returns:
+        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]
+    """
+    try:
+        # Convert to float32
+        image_array = image_array.astype(np.float32)
+
+        # Handle different input shapes
+        if len(image_array.shape) == 4:  # (batch, height, width, channels)
+            # Take the first image of the batch
+            image_array = image_array[0]
+
+        if len(image_array.shape) == 3:  # (height, width, channels)
+            if image_array.shape[2] == 3:
+                # Simple RGB-to-grayscale conversion (ITU-R 601 luma weights)
+                image_array = 0.299 * image_array[:, :, 0] + \
+                              0.587 * image_array[:, :, 1] + \
+                              0.114 * image_array[:, :, 2]
+            elif image_array.shape[2] == 1:
+                image_array = image_array.squeeze(-1)
+
+        # At this point image_array should be 2D (height, width)
+        if len(image_array.shape) != 2:
+            raise ValueError(f"Cannot process image with shape {image_array.shape}")
+
+        # Resize if needed. Scale up first if the array is already in [0, 1];
+        # otherwise casting to uint8 would zero out the image.
+        if image_array.shape != (28, 28):
+            if image_array.max() <= 1.0:
+                image_pil = Image.fromarray((image_array * 255).astype(np.uint8))
+            else:
+                image_pil = Image.fromarray(image_array.astype(np.uint8))
+            image_pil = image_pil.resize((28, 28), Image.Resampling.LANCZOS)
+            image_array = np.array(image_pil, dtype=np.float32)
+
+        # Normalize to [0, 1] if not already
+        if image_array.max() > 1.0:
+            image_array = image_array / 255.0
+
+        # Reshape to (1, 28, 28, 1)
+        image_array = image_array.reshape(1, 28, 28, 1)
+
+        return image_array
+
+    except Exception as e:
+        logger.error(f"Error preprocessing image from array: {e}")
+        raise ValueError(f"Failed to process image array: {str(e)}")
+
+
+def preprocess_stroke_data(strokes: list, canvas_size: int = 256) -> np.ndarray:
+    """
+    Convert stroke data (lists of coordinates) to a 28x28 image.
+    Useful if the VR application sends raw drawing coordinates.
+
+    Args:
+        strokes: List of strokes, where each stroke is a list of (x, y) coordinates.
+                 Example: [[(x1, y1), (x2, y2), ...], [(x3, y3), ...]]
+        canvas_size: Size of the virtual canvas (default: 256x256)
+
+    Returns:
+        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]
+    """
+    try:
+        # Create a blank canvas
+        canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
+
+        # Draw the strokes onto the canvas
+        for stroke in strokes:
+            if len(stroke) < 2:
+                continue
+
+            # Draw lines between consecutive points
+            for i in range(len(stroke) - 1):
+                x1, y1 = stroke[i]
+                x2, y2 = stroke[i + 1]
+
+                # Simple line drawing via linear interpolation
+                # (Bresenham's algorithm would be more precise)
+                points = _interpolate_points(x1, y1, x2, y2)
+                for x, y in points:
+                    if 0 <= x < canvas_size and 0 <= y < canvas_size:
+                        canvas[int(y), int(x)] = 255
+
+        # Convert the canvas to a PIL Image for resizing
+        image = Image.fromarray(canvas)
+        image = image.resize((28, 28), Image.Resampling.LANCZOS)
+
+        # Convert to a numpy array and normalize
+        image_array = np.array(image, dtype=np.float32) / 255.0
+
+        # Reshape to (1, 28, 28, 1)
+        image_array = image_array.reshape(1, 28, 28, 1)
+
+        return image_array
+
+    except Exception as e:
+        logger.error(f"Error preprocessing stroke data: {e}")
+        raise ValueError(f"Failed to process stroke data: {str(e)}")
+
+
+def _interpolate_points(x1: float, y1: float, x2: float, y2: float, num_points: int = 10) -> list:
+    """
+    Interpolate points between two coordinates for smooth line drawing.
+
+    Args:
+        x1, y1: Start coordinates
+        x2, y2: End coordinates
+        num_points: Number of segments to interpolate
+
+    Returns:
+        List of (x, y) coordinate tuples
+    """
+    points = []
+    for i in range(num_points + 1):
+        t = i / num_points
+        x = x1 + t * (x2 - x1)
+        y = y1 + t * (y2 - y1)
+        points.append((x, y))
+    return points
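
As a quick sanity check, the linear interpolation that `_interpolate_points` performs can be reproduced standalone (a sketch mirroring the helper above, not an import of it):

```python
def interpolate_points(x1, y1, x2, y2, num_points=10):
    """Evenly spaced points on the segment (x1, y1) -> (x2, y2), endpoints included."""
    return [
        (x1 + (i / num_points) * (x2 - x1), y1 + (i / num_points) * (y2 - y1))
        for i in range(num_points + 1)
    ]

pts = interpolate_points(0, 0, 10, 0, num_points=5)
# num_points counts segments, so 6 points come back and both endpoints are exact
```

Because the loop runs `num_points + 1` times, `num_points` is the number of segments, not points, which is why the stroke rasterizer always hits both endpoints of each line.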