aditya-g07 committed on
Commit
b10b0ba
·
1 Parent(s): e7568be

Deploy RetinaFace face detection API with Gradio SDK

- Added RetinaFace face detection models (MobileNet and ResNet)
- Implemented Gradio-based web interface and API endpoints
- Added utility modules for face detection processing
- Included model files: mobilenet0.25_Final.pth and Resnet50_Final.pth
- Added comprehensive documentation and deployment guides
- Added Thunkable integration examples and test scripts
- Ready for deployment on Hugging Face Spaces

DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,238 @@
1
+ # RetinaFace Face Detection API
2
+
3
+ A Gradio-based face detection service using RetinaFace models (MobileNet and ResNet backbones) deployed on Hugging Face Spaces.
4
+
5
+ ## Features
6
+
7
+ - 🔥 **Dual Model Support**: MobileNet (fast) and ResNet (accurate) backbones
8
+ - 📱 **Thunkable Compatible**: API endpoints for mobile app integration
9
+ - ⚡ **Real-time Detection**: Web interface and API endpoints
10
+ - 🎨 **Interactive UI**: Gradio web interface for easy testing
11
+ - 🚀 **Serverless**: Deployed on Hugging Face Spaces for free
12
+
13
+ ## Web Interface
14
+
15
+ Access the interactive web interface at your Hugging Face Space URL:
16
+ - Image upload and detection
17
+ - Model selection (MobileNet/ResNet)
18
+ - Confidence threshold adjustment
19
+ - Real-time results visualization
20
+ - API testing interface
21
+
22
+ ## API Endpoints
23
+
24
+ ### 1. Gradio API Endpoint
25
+ ```
26
+ POST /api/predict
27
+ ```
28
+ Main API endpoint compatible with Thunkable and other applications.
29
+
30
+ **Request Body:**
31
+ ```json
32
+ {
33
+ "data": [
34
+ "base64_encoded_image_string",
35
+ "mobilenet",
36
+ 0.5,
37
+ 0.4
38
+ ]
39
+ }
40
+ ```
41
+
42
+ **Response:**
43
+ ```json
44
+ {
45
+ "data": [
46
+ {
47
+ "faces": [
48
+ {
49
+ "bbox": {"x1": 100, "y1": 120, "x2": 200, "y2": 220},
50
+ "confidence": 0.95,
51
+ "landmarks": {
52
+ "right_eye": [130, 150],
53
+ "left_eye": [170, 150],
54
+ "nose": [150, 170],
55
+ "right_mouth": [135, 190],
56
+ "left_mouth": [165, 190]
57
+ }
58
+ }
59
+ ],
60
+ "processing_time": 0.1,
61
+ "model_used": "mobilenet",
62
+ "total_faces": 1
63
+ }
64
+ ]
65
+ }
66
+ ```
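+
+ For quick programmatic checks you can also call the same endpoint through the official `gradio_client` package. A minimal sketch, assuming the Space is published under the placeholder id `your-username/retinaface-api`; the `api_name` below is also a placeholder, since the "Use via API" page of your Space shows the actual endpoint name and argument order:
+
+ ```python
+ # Sketch using the official gradio_client package (pip install gradio_client).
+ # Space id and api_name are placeholders; check your Space's "Use via API"
+ # page for the real endpoint name.
+ import base64
+ from gradio_client import Client
+
+ with open("image.jpg", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode()
+
+ client = Client("your-username/retinaface-api")  # hypothetical Space id
+ result = client.predict(
+     image_b64,    # base64-encoded image
+     "mobilenet",  # model type
+     0.5,          # confidence threshold
+     0.4,          # NMS threshold
+     api_name="/predict",  # placeholder endpoint name
+ )
+ print(result)
+ ```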
67
+
68
+ ## Deployment Instructions
69
+
70
+ ### 1. Hugging Face Spaces Deployment
71
+
72
+ 1. **Create a new Space on Hugging Face:**
73
+ - Go to https://huggingface.co/spaces
74
+ - Click "Create new Space"
75
+ - Choose "Gradio" as SDK
76
+ - Set SDK version to 4.44.0
77
+ - Set visibility to "Public"
78
+
79
+ 2. **Upload your files** (or push them with `huggingface_hub`, as sketched after this list):
80
+ ```
81
+ ├── app.py # Main Gradio application
82
+ ├── requirements.txt # Python dependencies
83
+ ├── README.md # HF Spaces configuration
84
+ ├── mobilenet0.25_Final.pth # MobileNet model weights
85
+ ├── Resnet50_Final.pth # ResNet model weights
86
+ ├── models/
87
+ │ └── retinaface.py # RetinaFace model architecture
88
+ └── utils/
89
+ ├── box_utils.py # Bounding box utilities
90
+ ├── prior_box.py # Anchor box generation
91
+ └── py_cpu_nms.py # Non-maximum suppression
92
+ ```
93
+
94
+ 3. **Your Space will automatically build and deploy!**
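+
+ If you prefer the command line to the web UI, the files from step 2 can be pushed with `huggingface_hub`. A sketch, assuming you have run `huggingface-cli login` and that `your-username/retinaface-api` is your Space id:
+
+ ```python
+ # Optional alternative to uploading via the web UI.
+ # repo_id is a placeholder -- replace it with your actual Space.
+ from huggingface_hub import HfApi
+
+ api = HfApi()
+ api.upload_folder(
+     folder_path=".",                         # local project directory
+     repo_id="your-username/retinaface-api",  # hypothetical Space id
+     repo_type="space",
+ )
+ ```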
95
+
96
+ ### 2. Local Testing
97
+
98
+ 1. **Install dependencies:**
99
+ ```bash
100
+ pip install -r requirements.txt
101
+ ```
102
+
103
+ 2. **Run locally:**
104
+ ```bash
105
+ python app.py
106
+ ```
107
+
108
+ 3. **Test the API:**
109
+ ```bash
110
+ python test_api.py
111
+ ```
112
+
113
+ 4. **Access the web interface:**
114
+ Open http://localhost:7860 in your browser
115
+
116
+ ## Thunkable Integration
117
+
118
+ ### 1. Web API Component Setup
119
+ ```
120
+ URL: https://your-username-retinaface-api.hf.space/api/predict
121
+ Method: POST
122
+ Headers: Content-Type: application/json
123
+ Body: {
124
+ "data": [
125
+ "{{base64_image}}",
126
+ "mobilenet",
127
+ 0.5,
128
+ 0.4
129
+ ]
130
+ }
131
+ ```
132
+
133
+ ### 2. Response Handling in Thunkable
134
+ ```
135
+ When Web API receives data:
136
+ Set app variable "apiResponse" to response body
137
+ Set app variable "detectionData" to get property "data" of apiResponse
138
+ Set app variable "faces" to get property "faces" of detectionData[0]
139
+ Set app variable "faceCount" to get property "total_faces" of detectionData[0]
140
+
141
+ If faceCount > 0:
142
+ For each face in faces:
143
+ // Process face data (bbox, confidence, landmarks)
144
+ ```
145
+
146
+ ### 3. Base64 Image Conversion
147
+ ```
148
+ // In Thunkable, convert camera image to base64
149
+ Set app variable "imageBase64" to
150
+ call CloudinaryAPI.convertToBase64
151
+ mediaDB = Camera1.Picture
152
+ ```
153
+
154
+ ## Model Performance
155
+
156
+ | Model | Speed | Accuracy | Use Case |
157
+ |-------|-------|----------|----------|
158
+ | MobileNet | Fast | Good | Real-time mobile apps |
159
+ | ResNet50 | Slower | High | High-accuracy applications |
160
+
161
+ ## API Testing
162
+
163
+ Use the built-in API testing interface in the Gradio app:
164
+ 1. Go to the "📊 API Testing" tab
165
+ 2. Paste your base64 encoded image
166
+ 3. Select model and parameters
167
+ 4. Click "🧪 Test API"
168
+ 5. View the JSON response
169
+
170
+ ## Error Handling
171
+
172
+ The API includes comprehensive error handling:
173
+ - Invalid image data validation
174
+ - Model loading verification
175
+ - Detailed error responses in JSON format
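+
+ Because errors are reported inside the JSON payload rather than only via HTTP status codes, clients should check for an `error` key. A minimal sketch of that check, using the placeholder URL from the examples below:
+
+ ```python
+ import base64
+ import requests
+
+ with open("image.jpg", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode()
+
+ response = requests.post(
+     "https://your-space.hf.space/api/predict",  # placeholder URL
+     json={"data": [image_b64, "mobilenet", 0.5, 0.4]},
+ )
+ payload = response.json().get("data", [{}])[0]
+ if "error" in payload:
+     print(f"Detection failed: {payload['error']}")
+ else:
+     print(f"Detected {payload['total_faces']} face(s)")
+ ```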
176
+
177
+ ## Advantages of Gradio SDK
178
+
179
+ ✅ **Web Interface**: Built-in UI for testing and demonstration
180
+ ✅ **API Endpoints**: Automatic API generation at `/api/predict`
181
+ ✅ **Easy Deployment**: No Docker configuration needed
182
+ ✅ **Real-time Testing**: Interactive interface for immediate feedback
183
+ ✅ **Documentation**: Built-in API documentation
184
+ ✅ **Mobile Friendly**: Responsive web interface
185
+
186
+ ## Limitations
187
+
188
+ - **File Size**: Max upload size determined by Hugging Face Spaces
189
+ - **Concurrent Requests**: Subject to Hugging Face Spaces limits
190
+ - **Cold Starts**: First request may take longer due to model loading
191
+ - **Processing Time**: Heavy models may time out on the free tier
192
+
193
+ ## Example Integration Code
194
+
195
+ ### JavaScript/Thunkable
196
+ ```javascript
197
+ const response = await fetch('https://your-space.hf.space/api/predict', {
198
+ method: 'POST',
199
+ headers: { 'Content-Type': 'application/json' },
200
+ body: JSON.stringify({
201
+ data: [base64Image, "mobilenet", 0.5, 0.4]
202
+ })
203
+ });
204
+
205
+ const result = await response.json();
206
+ const faces = result.data[0].faces;
207
+ ```
208
+
209
+ ### Python
210
+ ```python
211
+ import requests
212
+ import base64
213
+
214
+ # Convert image to base64
215
+ with open('image.jpg', 'rb') as f:
216
+ image_b64 = base64.b64encode(f.read()).decode()
217
+
218
+ # Make API call
219
+ response = requests.post(
220
+ 'https://your-space.hf.space/api/predict',
221
+ json={"data": [image_b64, "mobilenet", 0.5, 0.4]}
222
+ )
223
+
224
+ result = response.json()
225
+ faces = result["data"][0]["faces"]
226
+ ```
227
+
228
+ ## Support
229
+
230
+ For issues or questions:
231
+ 1. Check the web interface at your Space URL
232
+ 2. Test locally using the provided test script
233
+ 3. Use the built-in API testing tab in Gradio
234
+ 4. Verify model files are correctly uploaded
235
+
236
+ ## License
237
+
238
+ Apache 2.0
README.md CHANGED
@@ -1,14 +1,9 @@
- ---
- title: RetinaFace Face Detection
- emoji: 🔥
+ title: RetinaFace Face Detection API
+ emoji: 😊
  colorFrom: blue
- colorTo: pink
+ colorTo: red
  sdk: gradio
- sdk_version: 5.35.0
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
- license: mit
- short_description: Retinaface face detection mobile0.25 + ResNet
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ license: apache-2.0
Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
THUNKABLE_EXAMPLES.md ADDED
@@ -0,0 +1,347 @@
1
+ # Thunkable Integration Examples for Gradio API
2
+
3
+ This file contains examples of how to integrate the RetinaFace Gradio API with Thunkable.
4
+
5
+ ## 1. Camera Capture and Face Detection
6
+
7
+ ### Blocks Setup:
8
+ ```
9
+ 1. Camera1 → TakePicture
10
+ 2. Convert image to base64
11
+ 3. Make API call to Gradio endpoint
12
+ 4. Process results
13
+ ```
14
+
15
+ ### Base64 Conversion Block:
16
+ ```
17
+ Set app variable "imageBase64" to
18
+ call CloudinaryAPI.convertToBase64
19
+ mediaDB = Camera1.Picture
20
+ ```
21
+
22
+ ### API Call Block:
23
+ ```
24
+ Web API1:
25
+ - URL: https://your-space-name.hf.space/api/predict
26
+ - Method: POST
27
+ - Headers: {"Content-Type": "application/json"}
28
+ - Body: {
29
+ "data": [
30
+ get app variable "imageBase64",
31
+ "mobilenet",
32
+ 0.5,
33
+ 0.4
34
+ ]
35
+ }
36
+ ```
37
+
38
+ ### Response Handling:
39
+ ```
40
+ When Web API1 receives data:
41
+ Set app variable "apiResponse" to responseBody
42
+ Set app variable "detectionData" to get property "data" of apiResponse
43
+ Set app variable "resultData" to get item 1 of list detectionData
44
+ Set app variable "faces" to get property "faces" of resultData
45
+ Set app variable "faceCount" to get property "total_faces" of resultData
46
+
47
+ If faceCount > 0:
48
+ For each item "face" in list "faces":
49
+ Set app variable "confidence" to get property "confidence" of object "face"
50
+ Set app variable "bbox" to get property "bbox" of object "face"
51
+
52
+ // Draw bounding box or show results
53
+ Set Label1.Text to join("Found face with confidence: ", confidence)
54
+ ```
55
+
56
+ ## 2. API Response Structure
57
+
58
+ ### Gradio API Response Format:
59
+ ```json
60
+ {
61
+ "data": [
62
+ {
63
+ "faces": [
64
+ {
65
+ "bbox": {"x1": 100, "y1": 120, "x2": 200, "y2": 220},
66
+ "confidence": 0.95,
67
+ "landmarks": {
68
+ "right_eye": [130, 150],
69
+ "left_eye": [170, 150],
70
+ "nose": [150, 170],
71
+ "right_mouth": [135, 190],
72
+ "left_mouth": [165, 190]
73
+ }
74
+ }
75
+ ],
76
+ "processing_time": 0.1,
77
+ "model_used": "mobilenet",
78
+ "total_faces": 1
79
+ }
80
+ ]
81
+ }
82
+ ```
83
+
84
+ ### Extracting Data in Thunkable:
85
+ ```
86
+ // Get the main detection result
87
+ Set app variable "result" to get item 1 of list (get property "data" of responseBody)
88
+
89
+ // Extract face information
90
+ Set app variable "faces" to get property "faces" of result
91
+ Set app variable "totalFaces" to get property "total_faces" of result
92
+ Set app variable "processingTime" to get property "processing_time" of result
93
+ Set app variable "modelUsed" to get property "model_used" of result
94
+
95
+ // For each face detected
96
+ For each item "face" in list "faces":
97
+ Set app variable "boundingBox" to get property "bbox" of face
98
+ Set app variable "confidence" to get property "confidence" of face
99
+ Set app variable "landmarks" to get property "landmarks" of face
100
+ ```
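+
+ For reference, the same extraction in Python is handy for checking the response shape before wiring up the blocks; this sketch parses a sample payload in the documented format:
+
+ ```python
+ import json
+
+ # Sample payload matching the documented response format above.
+ response_body = json.loads("""
+ {"data": [{"faces": [{"bbox": {"x1": 100, "y1": 120, "x2": 200, "y2": 220},
+                       "confidence": 0.95,
+                       "landmarks": {"right_eye": [130, 150], "left_eye": [170, 150],
+                                     "nose": [150, 170], "right_mouth": [135, 190],
+                                     "left_mouth": [165, 190]}}],
+            "processing_time": 0.1, "model_used": "mobilenet", "total_faces": 1}]}
+ """)
+
+ result = response_body["data"][0]
+ print(result["total_faces"], result["model_used"], result["processing_time"])
+ for face in result["faces"]:
+     print(face["confidence"], face["bbox"], face["landmarks"]["nose"])
+ ```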
101
+
102
+ ## 3. Error Handling
103
+
104
+ ### Connection Error:
105
+ ```
106
+ When Web API1 has error:
107
+ Set Label_Error.Text to "Failed to connect to face detection service"
108
+ Set Label_Error.Visible to true
109
+ ```
110
+
111
+ ### API Error Response:
112
+ ```
113
+ When Web API1 receives data:
114
+ If response status ≠ 200:
115
+ Set Label_Error.Text to "API Error: Check your image format"
116
+ Else:
117
+ // Check for error in response data
118
+ Set app variable "result" to get item 1 of list (get property "data" of responseBody)
119
+ If get property "error" of result ≠ null:
120
+ Set Label_Error.Text to get property "error" of result
121
+ Else:
122
+ // Process successful response
123
+ ```
124
+
125
+ ## 4. Real-time Detection Loop
126
+
127
+ ### Continuous Detection:
128
+ ```
129
+ When Screen opens:
130
+ Set app variable "isDetecting" to true
131
+ Call function "startDetectionLoop"
132
+
133
+ Function startDetectionLoop:
134
+ While app variable "isDetecting" = true:
135
+ Camera1.TakePicture
136
+ Wait 1 second // Adjust for performance - Gradio may be slower than FastAPI
137
+
138
+ When Camera1.AfterPicture:
139
+ If app variable "isDetecting" = true:
140
+ Call API for detection
141
+ ```
142
+
143
+ ## 5. Performance Optimization
144
+
145
+ ### Image Compression:
146
+ ```
147
+ Before API call:
148
+ Set app variable "compressedImage" to
149
+ call ImageUtils.compress
150
+ image = Camera1.Picture
151
+ quality = 0.7 // Reduce file size for faster upload
152
+ maxWidth = 640 // Gradio handles smaller images better
153
+ ```
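+
+ `ImageUtils.compress` above is Thunkable pseudocode. If you need the equivalent compression outside Thunkable (for example, in a test harness), here is a Pillow-based sketch of the same step:
+
+ ```python
+ # Cap width at 640px, re-encode as JPEG at ~70% quality, base64-encode the result.
+ import base64
+ import io
+
+ from PIL import Image
+
+ def compress_to_base64(path, max_width=640, quality=70):
+     img = Image.open(path).convert("RGB")
+     if img.width > max_width:
+         ratio = max_width / img.width
+         img = img.resize((max_width, int(img.height * ratio)))
+     buf = io.BytesIO()
+     img.save(buf, format="JPEG", quality=quality)
+     return base64.b64encode(buf.getvalue()).decode()
+ ```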
154
+
155
+ ### Model Selection for Performance:
156
+ ```
157
+ // For real-time applications, always use MobileNet
158
+ Set app variable "modelType" to "mobilenet"
159
+
160
+ // For high-accuracy single shots, use ResNet
161
+ Set app variable "modelType" to "resnet"
162
+ ```
163
+
164
+ ## 6. Complete API Integration Function
165
+
166
+ ### Thunkable Function: DetectFaces
167
+ ```
168
+ Function DetectFaces(imageToAnalyze, selectedModel, confidenceLevel):
169
+
170
+ // Convert image to base64
171
+ Set local variable "imageBase64" to
172
+ call CloudinaryAPI.convertToBase64
173
+ mediaDB = imageToAnalyze
174
+
175
+ // Prepare API request
176
+ Set local variable "requestData" to create object with:
177
+ "data" = create list with items:
178
+ - imageBase64
179
+ - selectedModel
180
+ - confidenceLevel
181
+ - 0.4 // NMS threshold
182
+
183
+ // Make API call
184
+ Call Web API1.POST with:
185
+ url = "https://your-space-name.hf.space/api/predict"
186
+ body = requestData
187
+ headers = create object with "Content-Type" = "application/json"
188
+
189
+ // Return to calling function
190
+ Return "API call initiated"
191
+ ```
192
+
193
+ ### Response Handler Function:
194
+ ```
195
+ Function ProcessDetectionResponse(responseBody):
196
+
197
+ // Extract main result
198
+ Set local variable "detectionResult" to get item 1 of list (get property "data" of responseBody)
199
+
200
+ // Check for errors
201
+ If get property "error" of detectionResult ≠ null:
202
+ Set Label_Status.Text to get property "error" of detectionResult
203
+ Return false
204
+
205
+ // Process successful detection
206
+ Set app variable "lastDetectionFaces" to get property "faces" of detectionResult
207
+ Set app variable "lastDetectionCount" to get property "total_faces" of detectionResult
208
+ Set app variable "lastProcessingTime" to get property "processing_time" of detectionResult
209
+
210
+ // Update UI
211
+ Set Label_FaceCount.Text to join("Faces detected: ", lastDetectionCount)
212
+ Set Label_ProcessingTime.Text to join("Processing time: ", lastProcessingTime, "s")
213
+
214
+ Return true
215
+ ```
216
+
217
+ ## 7. Advanced Features
218
+
219
+ ### Face Landmark Visualization:
220
+ ```
221
+ For each face in lastDetectionFaces:
222
+ Set local variable "landmarks" to get property "landmarks" of face
223
+
224
+ // Extract landmark coordinates
225
+ Set local variable "rightEye" to get property "right_eye" of landmarks
226
+ Set local variable "leftEye" to get property "left_eye" of landmarks
227
+ Set local variable "nose" to get property "nose" of landmarks
228
+ Set local variable "rightMouth" to get property "right_mouth" of landmarks
229
+ Set local variable "leftMouth" to get property "left_mouth" of landmarks
230
+
231
+ // Draw landmarks (if using drawing components)
232
+ Set Circle_RightEye.X to get item 1 of rightEye
233
+ Set Circle_RightEye.Y to get item 2 of rightEye
234
+ // ... repeat for other landmarks
235
+ ```
236
+
237
+ ### Confidence Filtering:
238
+ ```
239
+ Function FilterHighConfidenceFaces(allFaces, minConfidence):
240
+ Set local variable "filteredFaces" to create empty list
241
+
242
+ For each item "face" in list allFaces:
243
+ Set local variable "confidence" to get property "confidence" of face
244
+ If confidence ≥ minConfidence:
245
+ Add face to filteredFaces
246
+
247
+ Return filteredFaces
248
+ ```
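+
+ The same filter is a one-liner in Python, shown here as a reference for what the blocks compute:
+
+ ```python
+ # Python reference for the FilterHighConfidenceFaces blocks above.
+ def filter_high_confidence_faces(all_faces, min_confidence):
+     return [face for face in all_faces if face["confidence"] >= min_confidence]
+ ```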
249
+
250
+ ## 8. UI Components for Gradio Integration
251
+
252
+ ### Required Components:
253
+ ```
254
+ - Camera1 (for image capture)
255
+ - Button_Detect (trigger detection)
256
+ - Label_Status (show current status)
257
+ - Label_FaceCount (display number of faces)
258
+ - Label_ProcessingTime (show API response time)
259
+ - Label_Error (error messages)
260
+ - WebAPI1 (API communication)
261
+ - Dropdown_Model (model selection)
262
+ - Slider_Confidence (confidence threshold)
263
+ ```
264
+
265
+ ### Component Properties:
266
+ ```
267
+ Button_Detect:
268
+ - Text: "🔍 Detect Faces"
269
+ - Enabled: true when camera has image
270
+
271
+ Label_Status:
272
+ - Text: "Ready to detect faces"
273
+ - Font size: 16
274
+
275
+ Dropdown_Model:
276
+ - Options: ["mobilenet", "resnet"]
277
+ - Default: "mobilenet"
278
+
279
+ Slider_Confidence:
280
+ - Min: 0.1
281
+ - Max: 1.0
282
+ - Default: 0.5
283
+ - Step: 0.1
284
+ ```
285
+
286
+ ## 9. Testing Your Gradio Integration
287
+
288
+ ### Test Checklist:
289
+ - [ ] Camera permission granted
290
+ - [ ] Internet connection available
291
+ - [ ] Gradio API endpoint accessible (test in browser first)
292
+ - [ ] Base64 conversion working correctly
293
+ - [ ] Response parsing handles Gradio format
294
+ - [ ] Error handling for API failures
295
+ - [ ] UI updates with detection results
296
+
297
+ ### Debug Tips:
298
+ 1. Test Gradio web interface first at your Space URL
299
+ 2. Use the built-in "📊 API Testing" tab in Gradio
300
+ 3. Verify base64 encoding doesn't include data URL prefix
301
+ 4. Check that response format matches expected structure
302
+ 5. Monitor processing times (Gradio may be slower than FastAPI)
303
+
304
+ ## 10. Production Considerations
305
+
306
+ ### Performance:
307
+ - Gradio apps may have slightly higher latency than pure FastAPI
308
+ - Use MobileNet for real-time applications
309
+ - Consider image compression for faster uploads
310
+ - Implement proper loading indicators
311
+
312
+ ### Reliability:
313
+ - Handle Gradio app cold starts (the first request may time out)
314
+ - Implement retry logic for failed requests (see the sketch after this list)
315
+ - Cache successful results when appropriate
316
+ - Provide fallback options for offline scenarios
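+
+ A minimal retry-with-backoff sketch for the cold-start and transient-failure cases, assuming the `requests` library and a placeholder Space URL:
+
+ ```python
+ import time
+
+ import requests
+
+ def detect_with_retry(payload, retries=3, timeout=60):
+     """Retry the API call with exponential backoff to ride out cold starts."""
+     url = "https://your-space-name.hf.space/api/predict"  # placeholder
+     for attempt in range(retries):
+         try:
+             resp = requests.post(url, json=payload, timeout=timeout)
+             resp.raise_for_status()
+             return resp.json()
+         except requests.RequestException:
+             if attempt == retries - 1:
+                 raise
+             time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...
+ ```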
317
+
318
+ ### User Experience:
319
+ - Show clear loading states during API calls
320
+ - Provide informative error messages
321
+ - Allow users to switch between models
322
+ - Display confidence scores and processing times
323
+
324
+ ## 11. Sample Thunkable Blocks Layout
325
+
326
+ ### Main Detection Flow:
327
+ ```
328
+ When Button_Detect.Click:
329
+ → Set Label_Status.Text to "Capturing image..."
330
+ → Camera1.TakePicture
331
+
332
+ When Camera1.AfterPicture:
333
+ → Set Label_Status.Text to "Converting to base64..."
334
+ → Call CloudinaryAPI.convertToBase64
335
+
336
+ When CloudinaryAPI.GotBase64:
337
+ → Set Label_Status.Text to "Detecting faces..."
338
+ → Set app variable "imageB64" to base64Result
339
+ → Call function DetectFaces
340
+
341
+ When WebAPI1.GotText:
342
+ → Set Label_Status.Text to "Processing results..."
343
+ → Call function ProcessDetectionResponse
344
+ → Set Label_Status.Text to "Detection complete!"
345
+ ```
346
+
347
+ This comprehensive guide should help you successfully integrate your Gradio-based RetinaFace API with Thunkable!
app.py ADDED
@@ -0,0 +1,458 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import base64
7
+ from typing import List, Dict, Any
8
+ import tempfile
9
+ import time
10
+ from PIL import Image, ImageDraw
11
+ import json
12
+
13
+ # Import RetinaFace model components
14
+ from models.retinaface import RetinaFace
15
+ from utils.prior_box import PriorBox
16
+ from utils.py_cpu_nms import py_cpu_nms
17
+ from utils.box_utils import decode, decode_landm
18
+
19
+ # Global variables for models
20
+ mobilenet_model = None
21
+ resnet_model = None
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ def load_models():
25
+ """Load both MobileNet and ResNet RetinaFace models"""
26
+ global mobilenet_model, resnet_model
27
+
28
+ try:
29
+ # Load MobileNet model
30
+ mobilenet_model = RetinaFace(cfg=mobilenet_cfg, phase='test')
31
+ mobilenet_model.load_state_dict(torch.load('mobilenet0.25_Final.pth', map_location=device))
32
+ mobilenet_model.eval()
33
+ mobilenet_model = mobilenet_model.to(device)
34
+
35
+ # Load ResNet model
36
+ resnet_model = RetinaFace(cfg=resnet_cfg, phase='test')
37
+ resnet_model.load_state_dict(torch.load('Resnet50_Final.pth', map_location=device))
38
+ resnet_model.eval()
39
+ resnet_model = resnet_model.to(device)
40
+
41
+ print("Models loaded successfully!")
42
+ return "✅ Models loaded successfully!"
43
+
44
+ except Exception as e:
45
+ error_msg = f"❌ Error loading models: {e}"
46
+ print(error_msg)
47
+ return error_msg
48
+
49
+ # Model configurations
50
+ mobilenet_cfg = {
51
+ 'name': 'mobilenet0.25',
52
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
53
+ 'steps': [8, 16, 32],
54
+ 'variance': [0.1, 0.2],
55
+ 'clip': False,
56
+ 'loc_weight': 2.0,
57
+ 'gpu_train': True,
58
+ 'batch_size': 32,
59
+ 'ngpu': 1,
60
+ 'epoch': 250,
61
+ 'decay1': 190,
62
+ 'decay2': 220,
63
+ 'image_size': 640,
64
+ 'pretrain': True,
65
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
66
+ 'in_channel': 32,
67
+ 'out_channel': 64
68
+ }
69
+
70
+ resnet_cfg = {
71
+ 'name': 'Resnet50',
72
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
73
+ 'steps': [8, 16, 32],
74
+ 'variance': [0.1, 0.2],
75
+ 'clip': False,
76
+ 'loc_weight': 2.0,
77
+ 'gpu_train': True,
78
+ 'batch_size': 24,
79
+ 'ngpu': 4,
80
+ 'epoch': 100,
81
+ 'decay1': 70,
82
+ 'decay2': 90,
83
+ 'image_size': 840,
84
+ 'pretrain': True,
85
+ 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
86
+ 'in_channel': 256,
87
+ 'out_channel': 256
88
+ }
89
+
90
+ def detect_faces_core(image, model, cfg, confidence_threshold=0.02, nms_threshold=0.4):
91
+ """Core face detection function"""
92
+ start_time = time.time()
93
+
94
+ # Preprocessing
95
+ img = np.float32(image)
96
+ im_height, im_width, _ = img.shape
97
+ scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
98
+ img -= (104, 117, 123)
99
+ img = img.transpose(2, 0, 1)
100
+ img = torch.from_numpy(img).unsqueeze(0)
101
+ img = img.to(device)
102
+ scale = scale.to(device)
103
+
104
+ # Forward pass
105
+ with torch.no_grad():
106
+ loc, conf, landms = model(img)
107
+
108
+ # Post-processing
109
+ priorbox = PriorBox(cfg, image_size=(im_height, im_width))
110
+ priors = priorbox.forward()
111
+ priors = priors.to(device)
112
+ prior_data = priors.data
113
+ boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
114
+ boxes = boxes * scale / 1
115
+ boxes = boxes.cpu().numpy()
116
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
117
+ landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
118
+ scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
119
+ img.shape[3], img.shape[2], img.shape[3], img.shape[2],
120
+ img.shape[3], img.shape[2]])
121
+ scale1 = scale1.to(device)
122
+ landms = landms * scale1 / 1
123
+ landms = landms.cpu().numpy()
124
+
125
+ # Ignore low scores
126
+ inds = np.where(scores > confidence_threshold)[0]
127
+ boxes = boxes[inds]
128
+ landms = landms[inds]
129
+ scores = scores[inds]
130
+
131
+ # Keep top-K before NMS
132
+ order = scores.argsort()[::-1][:5000]
133
+ boxes = boxes[order]
134
+ landms = landms[order]
135
+ scores = scores[order]
136
+
137
+ # Do NMS
138
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
139
+ keep = py_cpu_nms(dets, nms_threshold)
140
+ dets = dets[keep, :]
141
+ landms = landms[keep]
142
+
143
+ # Format results
144
+ faces = []
145
+ for i in range(dets.shape[0]):
146
+ if dets[i, 4] < confidence_threshold:
147
+ continue
148
+
149
+ face = {
150
+ "bbox": {
151
+ "x1": float(dets[i, 0]),
152
+ "y1": float(dets[i, 1]),
153
+ "x2": float(dets[i, 2]),
154
+ "y2": float(dets[i, 3])
155
+ },
156
+ "confidence": float(dets[i, 4]),
157
+ "landmarks": {
158
+ "right_eye": [float(landms[i, 0]), float(landms[i, 1])],
159
+ "left_eye": [float(landms[i, 2]), float(landms[i, 3])],
160
+ "nose": [float(landms[i, 4]), float(landms[i, 5])],
161
+ "right_mouth": [float(landms[i, 6]), float(landms[i, 7])],
162
+ "left_mouth": [float(landms[i, 8]), float(landms[i, 9])]
163
+ }
164
+ }
165
+ faces.append(face)
166
+
167
+ processing_time = time.time() - start_time
168
+ return faces, processing_time
169
+
170
+ def draw_faces_on_image(image, faces):
171
+ """Draw bounding boxes and landmarks on image"""
172
+ if isinstance(image, np.ndarray):
173
+ # Convert numpy array to PIL Image
174
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
175
+
176
+ draw = ImageDraw.Draw(image)
177
+
178
+ for face in faces:
179
+ bbox = face["bbox"]
180
+ confidence = face["confidence"]
181
+ landmarks = face["landmarks"]
182
+
183
+ # Draw bounding box
184
+ draw.rectangle([bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]],
185
+ outline="red", width=2)
186
+
187
+ # Draw confidence score
188
+ draw.text((bbox["x1"], bbox["y1"] - 15),
189
+ f'{confidence:.2f}', fill="red")
190
+
191
+ # Draw landmarks
192
+ for landmark_name, (x, y) in landmarks.items():
193
+ draw.ellipse([x-2, y-2, x+2, y+2], fill="blue")
194
+
195
+ return image
196
+
197
+ def gradio_detect_faces(image, model_type, confidence_threshold, nms_threshold):
198
+ """Gradio interface function for face detection"""
199
+ if mobilenet_model is None or resnet_model is None:
200
+ return None, "❌ Models not loaded. Please wait for models to load.", ""
201
+
202
+ try:
203
+ # Convert PIL to OpenCV format
204
+ if isinstance(image, Image.Image):
205
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
206
+
207
+ # Select model
208
+ if model_type.lower() == "resnet":
209
+ model = resnet_model
210
+ cfg = resnet_cfg
211
+ model_name = "ResNet50"
212
+ else:
213
+ model = mobilenet_model
214
+ cfg = mobilenet_cfg
215
+ model_name = "MobileNet"
216
+
217
+ # Detect faces
218
+ faces, processing_time = detect_faces_core(
219
+ image, model, cfg, confidence_threshold, nms_threshold
220
+ )
221
+
222
+ # Draw results on image
223
+ result_image = draw_faces_on_image(image.copy(), faces)
224
+
225
+ # Create results text
226
+ results_text = f"🎯 **Detection Results**\n"
227
+ results_text += f"📱 Model: {model_name}\n"
228
+ results_text += f"⏱️ Processing Time: {processing_time:.3f}s\n"
229
+ results_text += f"👥 Faces Detected: {len(faces)}\n\n"
230
+
231
+ for i, face in enumerate(faces):
232
+ results_text += f"**Face {i+1}:**\n"
233
+ results_text += f" Confidence: {face['confidence']:.3f}\n"
234
+ bbox = face['bbox']
235
+ results_text += f" Location: ({bbox['x1']:.0f}, {bbox['y1']:.0f}) - ({bbox['x2']:.0f}, {bbox['y2']:.0f})\n\n"
236
+
237
+ # Create JSON output for API use
238
+ json_output = {
239
+ "faces": faces,
240
+ "processing_time": processing_time,
241
+ "model_used": model_name.lower(),
242
+ "total_faces": len(faces)
243
+ }
244
+
245
+ return result_image, results_text, json.dumps(json_output, indent=2)
246
+
247
+ except Exception as e:
248
+ error_msg = f"❌ Detection failed: {str(e)}"
249
+ return None, error_msg, ""
250
+
251
+ def api_detect_live(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
252
+ """API function for live detection (Thunkable compatible)"""
253
+ try:
254
+ # Decode base64 image
255
+ image_data = base64.b64decode(image_base64)
256
+ nparr = np.frombuffer(image_data, np.uint8)
257
+ image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
258
+
259
+ if image is None:
260
+ return {"error": "Invalid image data"}
261
+
262
+ # Select model
263
+ if model_type.lower() == "resnet":
264
+ model = resnet_model
265
+ cfg = resnet_cfg
266
+ model_name = "resnet"
267
+ else:
268
+ model = mobilenet_model
269
+ cfg = mobilenet_cfg
270
+ model_name = "mobilenet"
271
+
272
+ if model is None:
273
+ return {"error": f"{model_name} model not loaded"}
274
+
275
+ # Detect faces
276
+ faces, processing_time = detect_faces_core(
277
+ image, model, cfg, confidence_threshold, nms_threshold
278
+ )
279
+
280
+ return {
281
+ "faces": faces,
282
+ "processing_time": processing_time,
283
+ "model_used": model_name,
284
+ "total_faces": len(faces)
285
+ }
286
+
287
+ except Exception as e:
288
+ return {"error": f"Detection failed: {str(e)}"}
289
+
290
+ # Load models on startup
291
+ print("Loading RetinaFace models...")
292
+ load_status = load_models()
293
+
294
+ # Create Gradio interface
295
+ with gr.Blocks(title="RetinaFace Face Detection API", theme=gr.themes.Soft()) as demo:
296
+ gr.Markdown("""
297
+ # 🔥 RetinaFace Face Detection API
298
+
299
+ **Real-time face detection using RetinaFace with MobileNet and ResNet backbones**
300
+
301
+ - 📱 **Mobile App Ready**: Compatible with Thunkable and other mobile frameworks
302
+ - ⚡ **Dual Models**: MobileNet (fast) and ResNet (accurate)
303
+ - 🎯 **High Accuracy**: Detects faces with bounding boxes and 5-point landmarks
304
+ - 🌐 **API Endpoints**: Use `/api/predict` for programmatic access
305
+ """)
306
+
307
+ with gr.Row():
308
+ gr.Markdown(f"**Status**: {load_status}")
309
+
310
+ with gr.Tab("🖼️ Image Detection"):
311
+ with gr.Row():
312
+ with gr.Column():
313
+ input_image = gr.Image(type="pil", label="Upload Image")
314
+ model_choice = gr.Dropdown(
315
+ choices=["mobilenet", "resnet"],
316
+ value="mobilenet",
317
+ label="Model Type"
318
+ )
319
+ confidence_slider = gr.Slider(
320
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
321
+ label="Confidence Threshold"
322
+ )
323
+ nms_slider = gr.Slider(
324
+ minimum=0.1, maximum=1.0, value=0.4, step=0.1,
325
+ label="NMS Threshold"
326
+ )
327
+ detect_btn = gr.Button("🔍 Detect Faces", variant="primary")
328
+
329
+ with gr.Column():
330
+ output_image = gr.Image(label="Detection Results")
331
+ results_text = gr.Markdown(label="Results")
332
+
333
+ detect_btn.click(
334
+ fn=gradio_detect_faces,
335
+ inputs=[input_image, model_choice, confidence_slider, nms_slider],
336
+ outputs=[output_image, results_text]
337
+ )
338
+
339
+ with gr.Tab("🔗 API Documentation"):
340
+ gr.Markdown("""
341
+ ## API Endpoints for Thunkable Integration
342
+
343
+ ### 1. Live Detection Endpoint
344
+ ```
345
+ POST /api/predict
346
+ ```
347
+
348
+ **Request Body (JSON):**
349
+ ```json
350
+ {
351
+ "data": [
352
+ "base64_encoded_image_string",
353
+ "mobilenet",
354
+ 0.5,
355
+ 0.4
356
+ ]
357
+ }
358
+ ```
359
+
360
+ **Response:**
361
+ ```json
362
+ {
363
+ "data": [
364
+ {
365
+ "faces": [...],
366
+ "processing_time": 0.1,
367
+ "model_used": "mobilenet",
368
+ "total_faces": 2
369
+ }
370
+ ]
371
+ }
372
+ ```
373
+
374
+ ### 2. Thunkable Integration Example
375
+
376
+ **Web API Component Setup:**
377
+ - URL: `https://your-space-name.hf.space/api/predict`
378
+ - Method: `POST`
379
+ - Headers: `Content-Type: application/json`
380
+ - Body:
381
+ ```json
382
+ {
383
+ "data": [
384
+ "{{base64_image}}",
385
+ "mobilenet",
386
+ 0.5,
387
+ 0.4
388
+ ]
389
+ }
390
+ ```
391
+
392
+ ### 3. Model Performance
393
+
394
+ | Model | Speed | Accuracy | Best For |
395
+ |-------|-------|----------|----------|
396
+ | MobileNet | ⚡ Fast | 🎯 Good | Real-time mobile apps |
397
+ | ResNet50 | 🐌 Slower | 🎯🎯 High | High-accuracy applications |
398
+
399
+ ### 4. Response Format
400
+
401
+ Each detected face includes:
402
+ - **bbox**: Bounding box coordinates (x1, y1, x2, y2)
403
+ - **confidence**: Detection confidence score (0-1)
404
+ - **landmarks**: 5-point facial landmarks (eyes, nose, mouth corners)
405
+ """)
406
+
407
+ with gr.Tab("📊 API Testing"):
408
+ gr.Markdown("### Test the API with base64 encoded images")
409
+
410
+ with gr.Row():
411
+ with gr.Column():
412
+ test_image_b64 = gr.Textbox(
413
+ label="Base64 Encoded Image",
414
+ placeholder="Paste base64 encoded image here...",
415
+ lines=3
416
+ )
417
+ test_model = gr.Dropdown(
418
+ choices=["mobilenet", "resnet"],
419
+ value="mobilenet",
420
+ label="Model"
421
+ )
422
+ test_conf = gr.Number(value=0.5, label="Confidence")
423
+ test_nms = gr.Number(value=0.4, label="NMS Threshold")
424
+ test_btn = gr.Button("🧪 Test API", variant="secondary")
425
+
426
+ with gr.Column():
427
+ api_output = gr.JSON(label="API Response")
428
+
429
+ def test_api_function(image_b64, model, conf, nms):
430
+ if not image_b64.strip():
431
+ return {"error": "Please provide base64 encoded image"}
432
+
433
+ # Remove data URL prefix if present
434
+ if image_b64.startswith('data:image'):
435
+ image_b64 = image_b64.split(',')[1]
436
+
437
+ result = api_detect_live(image_b64, model, conf, nms)
438
+ return result
439
+
440
+ test_btn.click(
441
+ fn=test_api_function,
442
+ inputs=[test_image_b64, test_model, test_conf, test_nms],
443
+ outputs=[api_output]
444
+ )
445
+
446
+ # Custom API function for external calls
447
+ def predict_api(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
448
+ """API prediction function that matches Gradio's expected format"""
449
+ result = api_detect_live(image_base64, model_type, confidence_threshold, nms_threshold)
450
+ return result
451
+
452
+ # Launch the app
453
+ if __name__ == "__main__":
454
+ demo.launch(
455
+ server_name="0.0.0.0",
456
+ server_port=7860,
457
+ share=False
458
+ )
mobilenet0.25_Final.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2979b33ffafda5d74b6948cd7a5b9a7a62f62b949cef24e95fd15d2883a65220
3
+ size 1789735
models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Initialize empty __init__.py files for proper module imports
3
+ """
models/retinaface.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ from typing import Dict
6
+ import math
7
+
8
+ def conv_bn(inp, oup, stride=1, leaky=0):
9
+ return nn.Sequential(
10
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11
+ nn.BatchNorm2d(oup),
12
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
13
+ )
14
+
15
+ def conv_bn_no_relu(inp, oup, stride):
16
+ return nn.Sequential(
17
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
18
+ nn.BatchNorm2d(oup),
19
+ )
20
+
21
+ def conv_bn1X1(inp, oup, stride, leaky=0):
22
+ return nn.Sequential(
23
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
24
+ nn.BatchNorm2d(oup),
25
+ nn.LeakyReLU(negative_slope=leaky, inplace=True)
26
+ )
27
+
28
+ def conv_dw(inp, oup, stride, leaky=0.1):
29
+ return nn.Sequential(
30
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
31
+ nn.BatchNorm2d(inp),
32
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
33
+
34
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
35
+ nn.BatchNorm2d(oup),
36
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
37
+ )
38
+
39
+ class SSH(nn.Module):
40
+ def __init__(self, in_channel, out_channel):
41
+ super(SSH, self).__init__()
42
+ assert out_channel % 4 == 0
43
+ leaky = 0
44
+ if (out_channel <= 64):
45
+ leaky = 0.1
46
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
47
+
48
+ self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
49
+ self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
50
+
51
+ self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
52
+ self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
53
+
54
+ def forward(self, input):
55
+ conv3X3 = self.conv3X3(input)
56
+
57
+ conv5X5_1 = self.conv5X5_1(input)
58
+ conv5X5 = self.conv5X5_2(conv5X5_1)
59
+
60
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
61
+ conv7X7 = self.conv7x7_3(conv7X7_2)
62
+
63
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
64
+ out = F.relu(out)
65
+ return out
66
+
67
+ class FPN(nn.Module):
68
+ def __init__(self,in_channels_list,out_channels):
69
+ super(FPN,self).__init__()
70
+ leaky = 0
71
+ if (out_channels <= 64):
72
+ leaky = 0.1
73
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
74
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
75
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
76
+
77
+ self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
78
+ self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
79
+
80
+ def forward(self, input):
81
+ # names = list(input.keys())
82
+ input = list(input.values())
83
+
84
+ output1 = self.output1(input[0])
85
+ output2 = self.output2(input[1])
86
+ output3 = self.output3(input[2])
87
+
88
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
89
+ output2 = output2 + up3
90
+ output2 = self.merge2(output2)
91
+
92
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
93
+ output1 = output1 + up2
94
+ output1 = self.merge1(output1)
95
+
96
+ out = [output1, output2, output3]
97
+ return out
98
+
99
+ class MobileNetV1(nn.Module):
100
+ def __init__(self):
101
+ super(MobileNetV1, self).__init__()
102
+ self.stage1 = nn.Sequential(
103
+ conv_bn(3, 8, 2, leaky = 0.1), # 3
104
+ conv_dw(8, 16, 1), # 7
105
+ conv_dw(16, 32, 2), # 11
106
+ conv_dw(32, 32, 1), # 19
107
+ conv_dw(32, 64, 2), # 27
108
+ conv_dw(64, 64, 1), # 43
109
+ )
110
+ self.stage2 = nn.Sequential(
111
+ conv_dw(64, 128, 2), # 43 + 16 = 59
112
+ conv_dw(128, 128, 1), # 59 + 32 = 91
113
+ conv_dw(128, 128, 1), # 91 + 32 = 123
114
+ conv_dw(128, 128, 1), # 123 + 32 = 155
115
+ conv_dw(128, 128, 1), # 155 + 32 = 187
116
+ conv_dw(128, 128, 1), # 187 + 32 = 219
117
+ )
118
+ self.stage3 = nn.Sequential(
119
+ conv_dw(128, 256, 2), # 219 + 32 = 251
120
+ conv_dw(256, 256, 1), # 251 + 64 = 315
121
+ )
122
+ self.avg = nn.AdaptiveAvgPool2d((1,1))
123
+ self.fc = nn.Linear(256, 1000)
124
+
125
+ def forward(self, x):
126
+ x = self.stage1(x)
127
+ x = self.stage2(x)
128
+ x = self.stage3(x)
129
+ x = self.avg(x)
130
+ # x = self.model(x)
131
+ x = x.view(-1, 256)
132
+ x = self.fc(x)
133
+ return x
134
+
135
+ class ClassHead(nn.Module):
136
+ def __init__(self,inchannels=512,num_anchors=3):
137
+ super(ClassHead,self).__init__()
138
+ self.num_anchors = num_anchors
139
+ self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
140
+
141
+ def forward(self,x):
142
+ out = self.conv1x1(x)
143
+ out = out.permute(0,2,3,1).contiguous()
144
+
145
+ return out.view(out.shape[0], -1, 2)
146
+
147
+ class BboxHead(nn.Module):
148
+ def __init__(self,inchannels=512,num_anchors=3):
149
+ super(BboxHead,self).__init__()
150
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
151
+
152
+ def forward(self,x):
153
+ out = self.conv1x1(x)
154
+ out = out.permute(0,2,3,1).contiguous()
155
+
156
+ return out.view(out.shape[0], -1, 4)
157
+
158
+ class LandmarkHead(nn.Module):
159
+ def __init__(self,inchannels=512,num_anchors=3):
160
+ super(LandmarkHead,self).__init__()
161
+ self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
162
+
163
+ def forward(self,x):
164
+ out = self.conv1x1(x)
165
+ out = out.permute(0,2,3,1).contiguous()
166
+
167
+ return out.view(out.shape[0], -1, 10)
168
+
169
+ class RetinaFace(nn.Module):
170
+ def __init__(self, cfg = None, phase = 'train'):
171
+ """
172
+ :param cfg: Network related settings.
173
+ :param phase: train or test.
174
+ """
175
+ super(RetinaFace,self).__init__()
176
+ self.phase = phase
177
+ backbone = None
178
+ if cfg['name'] == 'mobilenet0.25':
179
+ backbone = MobileNetV1()
180
+ if cfg['pretrain']:
181
+ checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
182
+ from collections import OrderedDict
183
+ new_state_dict = OrderedDict()
184
+ for k, v in checkpoint['state_dict'].items():
185
+ name = k[7:] # remove module.
186
+ new_state_dict[name] = v
187
+ # load params
188
+ backbone.load_state_dict(new_state_dict)
189
+ elif cfg['name'] == 'Resnet50':
190
+ import torchvision.models as models
191
+ backbone = models.resnet50(pretrained=cfg['pretrain'])
192
+
193
+ # Wrap the backbone so forward() returns intermediate feature maps as a
194
+ # dict keyed by cfg['return_layers']; the FPN consumes these, and the
195
+ # MobileNetV1 backbone needs the wrapping just as ResNet50 does.
196
+ from torchvision.models._utils import IntermediateLayerGetter
197
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
198
+
199
+ in_channels_stage2 = cfg['in_channel']
200
+ in_channels_list = [
201
+ in_channels_stage2 * 2,
202
+ in_channels_stage2 * 4,
203
+ in_channels_stage2 * 8,
204
+ ]
205
+ out_channels = cfg['out_channel']
206
+ self.fpn = FPN(in_channels_list,out_channels)
207
+ self.ssh1 = SSH(out_channels, out_channels)
208
+ self.ssh2 = SSH(out_channels, out_channels)
209
+ self.ssh3 = SSH(out_channels, out_channels)
210
+
211
+ self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
212
+ self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
213
+ self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
214
+
215
+ def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
216
+ classhead = nn.ModuleList()
217
+ for i in range(fpn_num):
218
+ classhead.append(ClassHead(inchannels,anchor_num))
219
+ return classhead
220
+
221
+ def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
222
+ bboxhead = nn.ModuleList()
223
+ for i in range(fpn_num):
224
+ bboxhead.append(BboxHead(inchannels,anchor_num))
225
+ return bboxhead
226
+
227
+ def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
228
+ landmarkhead = nn.ModuleList()
229
+ for i in range(fpn_num):
230
+ landmarkhead.append(LandmarkHead(inchannels,anchor_num))
231
+ return landmarkhead
232
+
233
+ def forward(self,inputs):
234
+ out = self.body(inputs)
235
+
236
+ # FPN
237
+ fpn = self.fpn(out)
238
+
239
+ # SSH
240
+ feature1 = self.ssh1(fpn[0])
241
+ feature2 = self.ssh2(fpn[1])
242
+ feature3 = self.ssh3(fpn[2])
243
+ features = [feature1, feature2, feature3]
244
+
245
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
246
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
247
+ ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
248
+
249
+ if self.phase == 'train':
250
+ output = (bbox_regressions, classifications, ldm_regressions)
251
+ else:
252
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
253
+ return output
254
+
255
+ # Utils for ResNet backbone
256
+ class _utils_resnet:
257
+ class IntermediateLayerGetter(nn.ModuleDict):
258
+ """
259
+ Module wrapper that returns intermediate layers from a model
260
+
261
+ It has a strong assumption that the modules have been registered
262
+ into the model in the same order as they are used.
263
+ This means that one should **not** reuse the same nn.Module
264
+ twice in the forward if you want this to work.
265
+
266
+ Additionally, it is only able to query submodules that are directly
267
+ assigned to the model. So if `model` is passed, `model.feature1` can
268
+ be returned, but not `model.feature1.layer2`.
269
+
270
+ Arguments:
271
+ model (nn.Module): model on which we will extract the features
272
+ return_layers (Dict[name, new_name]): a dict containing the names
273
+ of the modules for which the activations will be returned as
274
+ the key of the dict, and the value of the dict is the name
275
+ of the returned activation (which the user can specify).
276
+
277
+ Examples::
278
+
279
+ >>> m = torchvision.models.resnet18(pretrained=True)
280
+ >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
281
+ >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
282
+ >>> {'layer1': 'feat1', 'layer3': 'feat2'})
283
+ >>> out = new_m(x)
284
+ >>> print([(k, v.shape) for k, v in out.items()])
285
+ >>> [('feat1', torch.Size([1, 64, 56, 56])),
286
+ >>> ('feat2', torch.Size([1, 256, 14, 14]))]
287
+ """
288
+ _version = 2
289
+ __annotations__ = {
290
+ "return_layers": Dict[str, str],
291
+ }
292
+
293
+ def __init__(self, model, return_layers):
294
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
295
+ raise ValueError("return_layers are not present in model")
296
+ orig_return_layers = return_layers
297
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
298
+ layers = OrderedDict()
299
+ for name, module in model.named_children():
300
+ layers[name] = module
301
+ if name in return_layers:
302
+ del return_layers[name]
303
+ if not return_layers:
304
+ break
305
+
306
+ super(_utils_resnet.IntermediateLayerGetter, self).__init__(layers)
307
+ self.return_layers = orig_return_layers
308
+
309
+ def forward(self, x):
310
+ result = OrderedDict()
311
+ for name, module in self.items():
312
+ x = module(x)
313
+ if name in self.return_layers:
314
+ out_name = self.return_layers[name]
315
+ result[out_name] = x
316
+ return result
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ gradio==4.44.0
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ opencv-python==4.8.1.78
5
+ numpy==1.24.3
6
+ Pillow==10.0.1
7
+ fastapi==0.104.1
8
+ uvicorn[standard]==0.24.0
9
+ python-multipart==0.0.6
10
+ pydantic==2.4.2
start.bat ADDED
@@ -0,0 +1,16 @@
1
+ @echo off
2
+ echo Starting RetinaFace Gradio API...
3
+
4
+ REM Check if model files exist
5
+ if not exist "mobilenet0.25_Final.pth" (
6
+ echo Warning: mobilenet0.25_Final.pth not found!
7
+ )
8
+
9
+ if not exist "Resnet50_Final.pth" (
10
+ echo Warning: Resnet50_Final.pth not found!
11
+ )
12
+
13
+ REM Start the Gradio app
14
+ python app.py
15
+
16
+ pause
start.sh ADDED
@@ -0,0 +1,15 @@
1
+ #!/bin/bash
2
+
3
+ echo "Starting RetinaFace Face Detection API..."
4
+
5
+ # Check if model files exist
6
+ if [ ! -f "mobilenet0.25_Final.pth" ]; then
7
+ echo "Warning: mobilenet0.25_Final.pth not found!"
8
+ fi
9
+
10
+ if [ ! -f "Resnet50_Final.pth" ]; then
11
+ echo "Warning: Resnet50_Final.pth not found!"
12
+ fi
13
+
14
+ # Start the Gradio app (app.py launches its own server on port 7860)
15
+ python app.py
test_api.py ADDED
@@ -0,0 +1,128 @@
1
+ import requests
2
+ import base64
3
+ import json
4
+
5
+ def test_gradio_api():
6
+ """Test the Gradio /api/predict endpoint"""
7
+ # You would replace this with actual base64 encoded image data
8
+ sample_image_path = "test_image.jpg" # Replace with your test image
9
+
10
+ try:
11
+ with open(sample_image_path, "rb") as image_file:
12
+ image_base64 = base64.b64encode(image_file.read()).decode('utf-8')
13
+ except FileNotFoundError:
14
+ print("Please add a test_image.jpg file to test the API")
15
+ return
16
+
17
+ url = "http://localhost:7860/api/predict"
18
+
19
+ payload = {
20
+ "data": [
21
+ image_base64,
22
+ "mobilenet",
23
+ 0.5,
24
+ 0.4
25
+ ]
26
+ }
27
+
28
+ response = requests.post(url, json=payload)
29
+
30
+ if response.status_code == 200:
31
+ result = response.json()
32
+ print("Success!")
33
+ print(f"API Response: {json.dumps(result, indent=2)}")
34
+
35
+ # Extract the actual detection data
36
+ if "data" in result and len(result["data"]) > 0:
37
+ detection_data = result["data"][0]
38
+ if "faces" in detection_data:
39
+ print(f"Detected {len(detection_data['faces'])} faces")
40
+ print(f"Processing time: {detection_data.get('processing_time', 'N/A'):.3f} seconds")
41
+ print(f"Model used: {detection_data.get('model_used', 'N/A')}")
42
+
43
+ for i, face in enumerate(detection_data['faces']):
44
+ print(f"Face {i+1}:")
45
+ print(f" Confidence: {face['confidence']:.3f}")
46
+ print(f" Bounding box: {face['bbox']}")
47
+ print(f" Landmarks: {face['landmarks']}")
48
+ else:
49
+ print("No face detection data in response")
50
+ else:
51
+ print("Unexpected response format")
52
+ else:
53
+ print(f"Error: {response.status_code}")
54
+ print(response.text)
55
+
56
+ def test_health_check():
57
+ """Test the Gradio app health"""
58
+ url = "http://localhost:7860/"
59
+
60
+ response = requests.get(url)
61
+
62
+ if response.status_code == 200:
63
+ print("Gradio app is running!")
64
+ print("You can access the web interface at: http://localhost:7860")
65
+ else:
66
+ print(f"Health check failed: {response.status_code}")
67
+
68
+ def test_direct_api_call():
69
+ """Test direct API call format that Thunkable would use"""
70
+ sample_image_path = "test_image.jpg" # Replace with your test image
71
+
72
+ try:
73
+ with open(sample_image_path, "rb") as image_file:
74
+ image_base64 = base64.b64encode(image_file.read()).decode('utf-8')
75
+ except FileNotFoundError:
76
+ print("Please add a test_image.jpg file to test the API")
77
+ return
78
+
79
+ url = "http://localhost:7860/api/predict"
80
+
81
+ # This is the format Thunkable will use
82
+ payload = {
83
+ "data": [image_base64, "mobilenet", 0.5, 0.4]
84
+ }
85
+
86
+ headers = {
87
+ "Content-Type": "application/json"
88
+ }
89
+
90
+ print("Testing Thunkable-compatible API call...")
91
+ response = requests.post(url, json=payload, headers=headers)
92
+
93
+ if response.status_code == 200:
94
+ result = response.json()
95
+ print("✅ Thunkable API call successful!")
96
+
97
+ # Parse the response as Thunkable would
98
+ if "data" in result and result["data"]:
99
+ detection_result = result["data"][0]
100
+ print(f"Faces detected: {detection_result.get('total_faces', 0)}")
101
+ print(f"Model used: {detection_result.get('model_used', 'unknown')}")
102
+ print(f"Processing time: {detection_result.get('processing_time', 0):.3f}s")
103
+ else:
104
+ print("❌ Unexpected response format")
105
+ else:
106
+ print(f"❌ API call failed: {response.status_code}")
107
+ print(response.text)
108
+
109
+ if __name__ == "__main__":
110
+ print("Testing RetinaFace Gradio API...")
111
+ print("=" * 50)
112
+
113
+ print("\n1. Health Check:")
114
+ test_health_check()
115
+
116
+ print("\n2. Gradio API Test:")
117
+ test_gradio_api()
118
+
119
+ print("\n3. Thunkable-Compatible API Test:")
120
+ test_direct_api_call()
121
+
122
+ print("\n" + "=" * 50)
123
+ print("Testing complete!")
124
+ print("\nFor Thunkable integration:")
125
+ print("- Use URL: http://localhost:7860/api/predict")
126
+ print("- Method: POST")
127
+ print("- Body format: {\"data\": [\"base64_image\", \"mobilenet\", 0.5, 0.4]}")
128
+ print("- Response will be in: response.data[0]")
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Initialize empty __init__.py files for proper module imports
3
+ """
utils/box_utils.py ADDED
@@ -0,0 +1,256 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def point_form(boxes):
5
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
6
+ representation for comparison to point form ground truth data.
7
+ Args:
8
+ boxes: (tensor) center-size default boxes from priorbox layers.
9
+ Return:
10
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
11
+ """
12
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
13
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
14
+
15
+
16
+ def center_size(boxes):
17
+ """ Convert prior_boxes to (cx, cy, w, h)
18
+ representation for comparison to center-size form ground truth data.
19
+ Args:
20
+ boxes: (tensor) point_form boxes
21
+ Return:
22
+ boxes: (tensor) Converted cx, cy, w, h form of boxes.
23
+ """
24
+ return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
25
+ boxes[:, 2:] - boxes[:, :2]), 1) # w, h
26
+
27
+
28
+ def intersect(box_a, box_b):
29
+ """ We resize both tensors to [A,B,2] without new malloc:
30
+ [A,2] -> [A,1,2] -> [A,B,2]
31
+ [B,2] -> [1,B,2] -> [A,B,2]
32
+ Then we compute the area of intersect between box_a and box_b.
33
+ Args:
34
+ box_a: (tensor) bounding boxes, Shape: [A,4].
35
+ box_b: (tensor) bounding boxes, Shape: [B,4].
36
+ Return:
37
+ (tensor) intersection area, Shape: [A,B].
38
+ """
39
+ A = box_a.size(0)
40
+ B = box_b.size(0)
41
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
42
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
43
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
44
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
45
+ inter = torch.clamp((max_xy - min_xy), min=0)
46
+ return inter[:, :, 0] * inter[:, :, 1]
47
+
48
+
49
+ def jaccard(box_a, box_b):
50
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
51
+ is simply the intersection over union of two boxes. Here we operate on
52
+ ground truth boxes and default boxes.
53
+ E.g.:
54
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
55
+ Args:
56
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
57
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
58
+ Return:
59
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
60
+ """
61
+ inter = intersect(box_a, box_b)
62
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
63
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
64
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
65
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
66
+ union = area_a + area_b - inter
67
+ return inter / union # [A,B]
68
+
69
+
+ def matrix_iou(a, b):
+     """
+     Return IoU of a and b, numpy version for data augmentation.
+     """
+     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+     area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+     return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+ def matrix_iof(a, b):
+     """
+     Return IoF (intersection over foreground) of a and b, numpy version
+     for data augmentation.
+     """
+     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+     return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+ def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+     """Match each prior box with the ground truth box of the highest jaccard
+     overlap, encode the bounding boxes, then return the matched indices
+     corresponding to both confidence and location preds.
+     Args:
+         threshold: (float) The overlap threshold used when matching boxes.
+         truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+         priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors, 4].
+         variances: (tensor) Variances corresponding to each prior coord,
+             Shape: [num_priors, 4].
+         labels: (tensor) All the class labels for the image, Shape: [num_obj].
+         landms: (tensor) Ground truth landmarks, Shape: [num_obj, 10].
+         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+         conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+         landm_t: (tensor) Tensor to be filled w/ encoded landmark targets.
+         idx: (int) current batch index
+     Return:
+         None; the encoded location, confidence, and landmark targets are
+         written into loc_t, conf_t, and landm_t in place.
+     """
+     # jaccard index
+     overlaps = jaccard(
+         truths,
+         point_form(priors)
+     )
+     # (Bipartite Matching)
+     # [1,num_objects] best prior for each ground truth
+     best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+     # ignore hard gt
+     valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+     best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+     if best_prior_idx_filter.shape[0] <= 0:
+         loc_t[idx] = 0
+         conf_t[idx] = 0
+         return
+
+     # [1,num_priors] best ground truth for each prior
+     best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+     best_truth_idx.squeeze_(0)
+     best_truth_overlap.squeeze_(0)
+     best_prior_idx.squeeze_(1)
+     best_prior_idx_filter.squeeze_(1)
+     best_prior_overlap.squeeze_(1)
+     best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
+     # TODO refactor: index best_prior_idx with long tensor
+     # ensure every gt matches with its prior of max overlap
+     for j in range(best_prior_idx.size(0)):  # force-match each gt's best anchor to that gt
+         best_truth_idx[best_prior_idx[j]] = j
+     matches = truths[best_truth_idx]  # Shape: [num_priors, 4]; a gt box is assigned to every anchor
+     conf = labels[best_truth_idx]     # Shape: [num_priors]; a label is assigned to every anchor
+     conf[best_truth_overlap < threshold] = 0  # anchors with overlap below the threshold become background
+     loc = encode(matches, priors, variances)
+
+     matches_landm = landms[best_truth_idx]
+     landm = encode_landm(matches_landm, priors, variances)
+     loc_t[idx] = loc    # [num_priors, 4] encoded offsets to learn
+     conf_t[idx] = conf  # [num_priors] top class label for each prior
+     landm_t[idx] = landm
+
+ def encode(matched, priors, variances):
+     """Encode the variances from the priorbox layers into the ground truth boxes
+     we have matched (based on jaccard overlap) with the prior boxes.
+     Args:
+         matched: (tensor) Coords of ground truth for each prior in point-form
+             Shape: [num_priors, 4].
+         priors: (tensor) Prior boxes in center-offset form
+             Shape: [num_priors, 4].
+         variances: (list[float]) Variances of priorboxes
+     Return:
+         encoded boxes (tensor), Shape: [num_priors, 4]
+     """
+
+     # dist b/t match center and prior's center
+     g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
+     # encode variance
+     g_cxcy /= (variances[0] * priors[:, 2:])
+     # match wh / prior wh
+     g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+     g_wh = torch.log(g_wh) / variances[1]
+     # return target for smooth_l1_loss
+     return torch.cat([g_cxcy, g_wh], 1)  # [num_priors, 4]
+
+ def encode_landm(matched, priors, variances):
+     """Encode the variances from the priorbox layers into the ground truth
+     landmarks we have matched (based on jaccard overlap) with the prior boxes.
+     Args:
+         matched: (tensor) Coords of ground truth landmarks for each prior,
+             Shape: [num_priors, 10].
+         priors: (tensor) Prior boxes in center-offset form
+             Shape: [num_priors, 4].
+         variances: (list[float]) Variances of priorboxes
+     Return:
+         encoded landm (tensor), Shape: [num_priors, 10]
+     """
+
+     # dist b/t match center and prior's center
+     matched = torch.reshape(matched, (matched.size(0), 5, 2))
+     priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+     priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+     priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+     priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+     priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+     g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+     # encode variance
+     g_cxcy /= (variances[0] * priors[:, :, 2:])
+     # g_cxcy /= priors[:, :, 2:]
+     g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+     # return target for smooth_l1_loss
+     return g_cxcy
+
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
+ def decode(loc, priors, variances):
+     """Decode locations from predictions using priors to undo
+     the encoding we did for offset regression at train time.
+     Args:
+         loc (tensor): location predictions for loc layers,
+             Shape: [num_priors, 4]
+         priors (tensor): Prior boxes in center-offset form.
+             Shape: [num_priors, 4].
+         variances: (list[float]) Variances of priorboxes
+     Return:
+         decoded bounding box predictions
+     """
+
+     boxes = torch.cat((
+         priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+     boxes[:, :2] -= boxes[:, 2:] / 2
+     boxes[:, 2:] += boxes[:, :2]
+     return boxes
+
+ def decode_landm(pre, priors, variances):
+     """Decode landm from predictions using priors to undo
+     the encoding we did for offset regression at train time.
+     Args:
+         pre (tensor): landm predictions for loc layers,
+             Shape: [num_priors, 10]
+         priors (tensor): Prior boxes in center-offset form.
+             Shape: [num_priors, 4].
+         variances: (list[float]) Variances of priorboxes
+     Return:
+         decoded landm predictions
+     """
+     landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+                         priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+                         priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+                         priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+                         priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+                         ), dim=1)
+     return landms
+
+
+ def log_sum_exp(x):
+     """Utility function for computing log_sum_exp.
+     This is used to compute the unaveraged confidence loss across
+     all examples in a batch.
+     Args:
+         x (Variable(tensor)): conf_preds from conf layers
+     """
+     x_max = x.data.max()
+     return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
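To see the geometry utilities above working together, here is a small round-trip sketch (not part of the commit). The toy boxes are made up, and `variances = [0.1, 0.2]` is an assumption, though it matches the usual SSD/RetinaFace convention:

```python
import torch
from utils.box_utils import point_form, jaccard, encode, decode

# Two priors in center-size form (cx, cy, w, h), normalized to [0, 1]
priors = torch.tensor([[0.30, 0.30, 0.20, 0.20],
                       [0.70, 0.70, 0.30, 0.30]])
# One ground-truth box in point form (xmin, ymin, xmax, ymax)
truth = torch.tensor([[0.22, 0.21, 0.41, 0.40]])

variances = [0.1, 0.2]  # assumed; standard SSD/RetinaFace values

overlaps = jaccard(truth, point_form(priors))    # [1, 2] IoU matrix
offsets = decode_input = encode(truth.expand(2, 4), priors, variances)
recovered = decode(offsets, priors, variances)   # undoes the encoding

print(overlaps)
print(torch.allclose(recovered, truth.expand(2, 4), atol=1e-5))  # True
```

The round trip confirms that `decode` exactly inverts `encode` for a given prior and variance pair, which is what lets the detector regress offsets at train time and recover absolute boxes at inference.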
utils/prior_box.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ from itertools import product as product
+ import numpy as np
+ from math import ceil
+
+
+ class PriorBox(object):
+     def __init__(self, cfg, image_size=None, phase='train'):
+         super(PriorBox, self).__init__()
+         self.min_sizes = cfg['min_sizes']
+         self.steps = cfg['steps']
+         self.clip = cfg['clip']
+         self.image_size = image_size
+         self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
+         self.name = "s"
+
+     def forward(self):
+         anchors = []
+         for k, f in enumerate(self.feature_maps):
+             min_sizes = self.min_sizes[k]
+             for i, j in product(range(f[0]), range(f[1])):
+                 for min_size in min_sizes:
+                     s_kx = min_size / self.image_size[1]
+                     s_ky = min_size / self.image_size[0]
+                     dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+                     dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+                     for cy, cx in product(dense_cy, dense_cx):
+                         anchors += [cx, cy, s_kx, s_ky]
+
+         # back to torch land
+         output = torch.Tensor(anchors).view(-1, 4)
+         if self.clip:
+             output.clamp_(max=1, min=0)
+         return output
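A short usage sketch for `PriorBox` (illustrative only — the `cfg` values below follow common RetinaFace MobileNet settings, but the authoritative values live in this repo's model configuration):

```python
from utils.prior_box import PriorBox

# Assumed config in the shape PriorBox expects: anchor sizes per feature
# map, the stride of each feature map, and whether to clamp to [0, 1].
cfg = {
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'clip': False,
}

priorbox = PriorBox(cfg, image_size=(640, 640))
priors = priorbox.forward()  # (num_anchors, 4) in normalized (cx, cy, w, h)
print(priors.shape)          # torch.Size([16800, 4]) with the values above
```

Each stride produces one feature map (80×80, 40×40, 20×20 for a 640×640 input), and every cell emits one anchor per `min_size`, which is where the 16,800 anchors come from.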
utils/py_cpu_nms.py ADDED
@@ -0,0 +1,31 @@
+ import numpy as np
+
+ def py_cpu_nms(dets, thresh):
+     """Pure Python NMS baseline."""
+     x1 = dets[:, 0]
+     y1 = dets[:, 1]
+     x2 = dets[:, 2]
+     y2 = dets[:, 3]
+     scores = dets[:, 4]
+
+     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+     order = scores.argsort()[::-1]
+
+     keep = []
+     while order.size > 0:
+         i = order[0]
+         keep.append(i)
+         xx1 = np.maximum(x1[i], x1[order[1:]])
+         yy1 = np.maximum(y1[i], y1[order[1:]])
+         xx2 = np.minimum(x2[i], x2[order[1:]])
+         yy2 = np.minimum(y2[i], y2[order[1:]])
+
+         w = np.maximum(0.0, xx2 - xx1 + 1)
+         h = np.maximum(0.0, yy2 - yy1 + 1)
+         inter = w * h
+         ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+         inds = np.where(ovr <= thresh)[0]
+         order = order[inds + 1]
+
+     return keep
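Finally, a toy run of `py_cpu_nms` with made-up numbers: each row of `dets` is `(x1, y1, x2, y2, score)`, and the function returns the indices of the detections that survive suppression at the given IoU threshold:

```python
import numpy as np
from utils.py_cpu_nms import py_cpu_nms

# Three detections: two heavily overlapping boxes and one separate box.
dets = np.array([
    [100, 100, 200, 200, 0.95],  # highest score, kept
    [105, 105, 205, 205, 0.80],  # IoU ~0.82 with the first box, suppressed
    [300, 300, 380, 380, 0.70],  # no overlap with the others, kept
], dtype=np.float32)

keep = py_cpu_nms(dets, thresh=0.4)
print(keep)        # indices 0 and 2 survive
print(dets[keep])  # surviving detections, highest score first
```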