Spaces · Runtime error

Commit b10b0ba · Parent(s): e7568be

Deploy RetinaFace face detection API with Gradio SDK

- Added RetinaFace face detection models (MobileNet and ResNet)
- Implemented Gradio-based web interface and API endpoints
- Added utility modules for face detection processing
- Included model files: mobilenet0.25_Final.pth and Resnet50_Final.pth
- Added comprehensive documentation and deployment guides
- Added Thunkable integration examples and test scripts
- Ready for deployment on Hugging Face Spaces
- DEPLOYMENT_GUIDE.md +238 -0
- README.md +5 -10
- Resnet50_Final.pth +3 -0
- THUNKABLE_EXAMPLES.md +347 -0
- app.py +458 -0
- mobilenet0.25_Final.pth +3 -0
- models/__init__.py +3 -0
- models/retinaface.py +316 -0
- requirements.txt +10 -0
- start.bat +16 -0
- start.sh +15 -0
- test_api.py +128 -0
- utils/__init__.py +3 -0
- utils/box_utils.py +256 -0
- utils/prior_box.py +34 -0
- utils/py_cpu_nms.py +31 -0
DEPLOYMENT_GUIDE.md
ADDED
@@ -0,0 +1,238 @@
# RetinaFace Face Detection API

A Gradio-based face detection service using RetinaFace models (MobileNet and ResNet backbones) deployed on Hugging Face Spaces.

## Features

- 🔥 **Dual Model Support**: MobileNet (fast) and ResNet (accurate) backbones
- 📱 **Thunkable Compatible**: API endpoints for mobile app integration
- ⚡ **Real-time Detection**: Web interface and API endpoints
- 🎨 **Interactive UI**: Gradio web interface for easy testing
- 🚀 **Serverless**: Deployed on Hugging Face Spaces for free

## Web Interface

Access the interactive web interface at your Hugging Face Space URL:
- Image upload and detection
- Model selection (MobileNet/ResNet)
- Confidence threshold adjustment
- Real-time results visualization
- API testing interface

## API Endpoints

### 1. Gradio API Endpoint
```
POST /api/predict
```
Main API endpoint compatible with Thunkable and other applications.

**Request Body:**
```json
{
  "data": [
    "base64_encoded_image_string",
    "mobilenet",
    0.5,
    0.4
  ]
}
```

**Response:**
```json
{
  "data": [
    {
      "faces": [
        {
          "bbox": {"x1": 100, "y1": 120, "x2": 200, "y2": 220},
          "confidence": 0.95,
          "landmarks": {
            "right_eye": [130, 150],
            "left_eye": [170, 150],
            "nose": [150, 170],
            "right_mouth": [135, 190],
            "left_mouth": [165, 190]
          }
        }
      ],
      "processing_time": 0.1,
      "model_used": "mobilenet",
      "total_faces": 1
    }
  ]
}
```

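The `bbox` values are absolute pixel coordinates in the submitted image, so detected faces can be cropped straight from the original file. A minimal Python sketch (assumes `result` holds the parsed JSON response and `photo.jpg` is the image that was sent):

```python
from PIL import Image

image = Image.open("photo.jpg")
for i, face in enumerate(result["data"][0]["faces"]):
    bbox = face["bbox"]
    # bbox coordinates are absolute pixels: (left, upper, right, lower)
    crop = image.crop((bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]))
    crop.save(f"face_{i}_{face['confidence']:.2f}.jpg")
```
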
## Deployment Instructions

### 1. Hugging Face Spaces Deployment

1. **Create a new Space on Hugging Face:**
   - Go to https://huggingface.co/spaces
   - Click "Create new Space"
   - Choose "Gradio" as the SDK
   - Set the SDK version to 4.44.0
   - Set visibility to "Public"

2. **Upload your files:**
   ```
   ├── app.py                     # Main Gradio application
   ├── requirements.txt           # Python dependencies
   ├── README.md                  # HF Spaces configuration
   ├── mobilenet0.25_Final.pth    # MobileNet model weights
   ├── Resnet50_Final.pth         # ResNet model weights
   ├── models/
   │   └── retinaface.py          # RetinaFace model architecture
   └── utils/
       ├── box_utils.py           # Bounding box utilities
       ├── prior_box.py           # Anchor box generation
       └── py_cpu_nms.py          # Non-maximum suppression
   ```

3. **Your Space will automatically build and deploy!**

### 2. Local Testing

1. **Install dependencies:**
   ```bash
   pip install -r requirements.txt
   ```

2. **Run locally:**
   ```bash
   python app.py
   ```

3. **Test the API:**
   ```bash
   python test_api.py
   ```

4. **Access the web interface:**
   Open http://localhost:7860 in your browser.

## Thunkable Integration

### 1. Web API Component Setup
```
URL: https://your-username-retinaface-api.hf.space/api/predict
Method: POST
Headers: Content-Type: application/json
Body: {
  "data": [
    "{{base64_image}}",
    "mobilenet",
    0.5,
    0.4
  ]
}
```

### 2. Response Handling in Thunkable
```
When Web API receives data:
  Set app variable "apiResponse" to response body
  Set app variable "detectionData" to get property "data" of apiResponse
  Set app variable "faces" to get property "faces" of detectionData[0]
  Set app variable "faceCount" to get property "total_faces" of detectionData[0]

  If faceCount > 0:
    For each face in faces:
      // Process face data (bbox, confidence, landmarks)
```

### 3. Base64 Image Conversion
```
// In Thunkable, convert camera image to base64
Set app variable "imageBase64" to
  call CloudinaryAPI.convertToBase64
    mediaDB = Camera1.Picture
```

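Outside Thunkable (for example when testing from a desktop), the same base64 payload can be produced with a few lines of Python; note that the API expects a bare base64 string, so any `data:image/...;base64,` prefix must be stripped first (the file name is illustrative):

```python
import base64

with open("photo.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Strings coming from a JS canvas or webview may carry a data URL prefix
if image_b64.startswith("data:image"):
    image_b64 = image_b64.split(",", 1)[1]
```
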
## Model Performance

| Model     | Speed  | Accuracy | Use Case                   |
|-----------|--------|----------|----------------------------|
| MobileNet | Fast   | Good     | Real-time mobile apps      |
| ResNet50  | Slower | High     | High-accuracy applications |

## API Testing

Use the built-in API testing interface in the Gradio app:
1. Go to the "📊 API Testing" tab
2. Paste your base64 encoded image
3. Select model and parameters
4. Click "🧪 Test API"
5. View the JSON response

## Error Handling

The API includes comprehensive error handling:
- Invalid image data validation
- Model loading verification
- Detailed error responses in JSON format, as in the example below

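When a request fails, the detection object is replaced by one carrying an `error` key, mirroring the error returns in `app.py` (the message text varies with the failure):

```json
{
  "data": [
    { "error": "Invalid image data" }
  ]
}
```
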
## Advantages of Gradio SDK

✅ **Web Interface**: Built-in UI for testing and demonstration
✅ **API Endpoints**: Automatic API generation at `/api/predict`
✅ **Easy Deployment**: No Docker configuration needed
✅ **Real-time Testing**: Interactive interface for immediate feedback
✅ **Documentation**: Built-in API documentation
✅ **Mobile Friendly**: Responsive web interface

## Limitations

- **File Size**: Max upload size determined by Hugging Face Spaces
- **Concurrent Requests**: Subject to Hugging Face Spaces limits
- **Cold Starts**: First request may take longer due to model loading
- **Processing Time**: Heavy models may time out on the free tier

## Example Integration Code

### JavaScript/Thunkable
```javascript
const response = await fetch('https://your-space.hf.space/api/predict', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    data: [base64Image, "mobilenet", 0.5, 0.4]
  })
});

const result = await response.json();
const faces = result.data[0].faces;
```

### Python
```python
import requests
import base64

# Convert image to base64
with open('image.jpg', 'rb') as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Make API call
response = requests.post(
    'https://your-space.hf.space/api/predict',
    json={"data": [image_b64, "mobilenet", 0.5, 0.4]}
)

result = response.json()
faces = result["data"][0]["faces"]
```

## Support

For issues or questions:
1. Check the web interface at your Space URL
2. Test locally using the provided test script
3. Use the built-in API testing tab in Gradio
4. Verify model files are correctly uploaded

## License

Apache 2.0
README.md
CHANGED
@@ -1,14 +1,9 @@
-
-
-emoji: 🔥
+title: RetinaFace Face Detection API
+emoji: 😊
 colorFrom: blue
-colorTo:
+colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
-license:
-short_description: Retinaface face detection mobile0.25 + ResNet
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+license: apache-2.0
Resnet50_Final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
size 109497761
THUNKABLE_EXAMPLES.md
ADDED
@@ -0,0 +1,347 @@
# Thunkable Integration Examples for Gradio API

This file contains examples of how to integrate the RetinaFace Gradio API with Thunkable.

## 1. Camera Capture and Face Detection

### Blocks Setup:
```
1. Camera1 → TakePicture
2. Convert image to base64
3. Make API call to Gradio endpoint
4. Process results
```

### Base64 Conversion Block:
```
Set app variable "imageBase64" to
  call CloudinaryAPI.convertToBase64
    mediaDB = Camera1.Picture
```

### API Call Block:
```
Web API1:
- URL: https://your-space-name.hf.space/api/predict
- Method: POST
- Headers: {"Content-Type": "application/json"}
- Body: {
    "data": [
      get app variable "imageBase64",
      "mobilenet",
      0.5,
      0.4
    ]
  }
```

### Response Handling:
```
When Web API1 receives data:
  Set app variable "apiResponse" to responseBody
  Set app variable "detectionData" to get property "data" of apiResponse
  Set app variable "resultData" to get item 1 of list detectionData
  Set app variable "faces" to get property "faces" of resultData
  Set app variable "faceCount" to get property "total_faces" of resultData

  If faceCount > 0:
    For each item "face" in list "faces":
      Set app variable "confidence" to get property "confidence" of object "face"
      Set app variable "bbox" to get property "bbox" of object "face"

      // Draw bounding box or show results
      Set Label1.Text to join("Found face with confidence: ", confidence)
```

## 2. API Response Structure

### Gradio API Response Format:
```json
{
  "data": [
    {
      "faces": [
        {
          "bbox": {"x1": 100, "y1": 120, "x2": 200, "y2": 220},
          "confidence": 0.95,
          "landmarks": {
            "right_eye": [130, 150],
            "left_eye": [170, 150],
            "nose": [150, 170],
            "right_mouth": [135, 190],
            "left_mouth": [165, 190]
          }
        }
      ],
      "processing_time": 0.1,
      "model_used": "mobilenet",
      "total_faces": 1
    }
  ]
}
```

### Extracting Data in Thunkable:
```
// Get the main detection result
Set app variable "result" to get item 1 of list (get property "data" of responseBody)

// Extract face information
Set app variable "faces" to get property "faces" of result
Set app variable "totalFaces" to get property "total_faces" of result
Set app variable "processingTime" to get property "processing_time" of result
Set app variable "modelUsed" to get property "model_used" of result

// For each face detected
For each item "face" in list "faces":
  Set app variable "boundingBox" to get property "bbox" of face
  Set app variable "confidence" to get property "confidence" of face
  Set app variable "landmarks" to get property "landmarks" of face
```

## 3. Error Handling

### Connection Error:
```
When Web API1 has error:
  Set Label_Error.Text to "Failed to connect to face detection service"
  Set Label_Error.Visible to true
```

### API Error Response:
```
When Web API1 receives data:
  If response status ≠ 200:
    Set Label_Error.Text to "API Error: Check your image format"
  Else:
    // Check for error in response data
    Set app variable "result" to get item 1 of list (get property "data" of responseBody)
    If get property "error" of result ≠ null:
      Set Label_Error.Text to get property "error" of result
    Else:
      // Process successful response
```

## 4. Real-time Detection Loop

### Continuous Detection:
```
When Screen opens:
  Set app variable "isDetecting" to true
  Call function "startDetectionLoop"

Function startDetectionLoop:
  While app variable "isDetecting" = true:
    Camera1.TakePicture
    Wait 1 second  // Adjust for performance - Gradio may be slower than FastAPI

When Camera1.AfterPicture:
  If app variable "isDetecting" = true:
    Call API for detection
```

## 5. Performance Optimization

### Image Compression:
```
Before API call:
  Set app variable "compressedImage" to
    call ImageUtils.compress
      image = Camera1.Picture
      quality = 0.7    // Reduce file size for faster upload
      maxWidth = 640   // Gradio handles smaller images better
```

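For desktop-side testing, an equivalent pre-compression step can be sketched in Python with Pillow; the 640-pixel width and 0.7 quality mirror the values above (the function and file names are illustrative):

```python
import base64
import io

from PIL import Image

def compress_to_base64(path, max_width=640, quality=70):
    """Resize and JPEG-compress an image, then return it base64-encoded."""
    img = Image.open(path)
    if img.width > max_width:
        ratio = max_width / img.width
        img = img.resize((max_width, int(img.height * ratio)))
    buf = io.BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=quality)
    return base64.b64encode(buf.getvalue()).decode()
```
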
### Model Selection for Performance:
```
// For real-time applications, always use MobileNet
Set app variable "modelType" to "mobilenet"

// For high-accuracy single shots, use ResNet
Set app variable "modelType" to "resnet"
```

## 6. Complete API Integration Function

### Thunkable Function: DetectFaces
```
Function DetectFaces(imageToAnalyze, selectedModel, confidenceLevel):

  // Convert image to base64
  Set local variable "imageBase64" to
    call CloudinaryAPI.convertToBase64
      mediaDB = imageToAnalyze

  // Prepare API request
  Set local variable "requestData" to create object with:
    "data" = create list with items:
      - imageBase64
      - selectedModel
      - confidenceLevel
      - 0.4  // NMS threshold

  // Make API call
  Call Web API1.POST with:
    url = "https://your-space-name.hf.space/api/predict"
    body = requestData
    headers = create object with "Content-Type" = "application/json"

  // Return to calling function
  Return "API call initiated"
```

### Response Handler Function:
```
Function ProcessDetectionResponse(responseBody):

  // Extract main result
  Set local variable "detectionResult" to get item 1 of list (get property "data" of responseBody)

  // Check for errors
  If get property "error" of detectionResult ≠ null:
    Set Label_Status.Text to get property "error" of detectionResult
    Return false

  // Process successful detection
  Set app variable "lastDetectionFaces" to get property "faces" of detectionResult
  Set app variable "lastDetectionCount" to get property "total_faces" of detectionResult
  Set app variable "lastProcessingTime" to get property "processing_time" of detectionResult

  // Update UI
  Set Label_FaceCount.Text to join("Faces detected: ", lastDetectionCount)
  Set Label_ProcessingTime.Text to join("Processing time: ", lastProcessingTime, "s")

  Return true
```

## 7. Advanced Features

### Face Landmark Visualization:
```
For each face in lastDetectionFaces:
  Set local variable "landmarks" to get property "landmarks" of face

  // Extract landmark coordinates
  Set local variable "rightEye" to get property "right_eye" of landmarks
  Set local variable "leftEye" to get property "left_eye" of landmarks
  Set local variable "nose" to get property "nose" of landmarks
  Set local variable "rightMouth" to get property "right_mouth" of landmarks
  Set local variable "leftMouth" to get property "left_mouth" of landmarks

  // Draw landmarks (if using drawing components)
  Set Circle_RightEye.X to get item 1 of rightEye
  Set Circle_RightEye.Y to get item 2 of rightEye
  // ... repeat for other landmarks
```

### Confidence Filtering:
```
Function FilterHighConfidenceFaces(allFaces, minConfidence):
  Set local variable "filteredFaces" to create empty list

  For each item "face" in list allFaces:
    Set local variable "confidence" to get property "confidence" of face
    If confidence ≥ minConfidence:
      Add face to filteredFaces

  Return filteredFaces
```

## 8. UI Components for Gradio Integration

### Required Components:
```
- Camera1 (for image capture)
- Button_Detect (trigger detection)
- Label_Status (show current status)
- Label_FaceCount (display number of faces)
- Label_ProcessingTime (show API response time)
- Label_Error (error messages)
- WebAPI1 (API communication)
- Dropdown_Model (model selection)
- Slider_Confidence (confidence threshold)
```

### Component Properties:
```
Button_Detect:
- Text: "🔍 Detect Faces"
- Enabled: true when camera has image

Label_Status:
- Text: "Ready to detect faces"
- Font size: 16

Dropdown_Model:
- Options: ["mobilenet", "resnet"]
- Default: "mobilenet"

Slider_Confidence:
- Min: 0.1
- Max: 1.0
- Default: 0.5
- Step: 0.1
```

## 9. Testing Your Gradio Integration

### Test Checklist:
- [ ] Camera permission granted
- [ ] Internet connection available
- [ ] Gradio API endpoint accessible (test in browser first)
- [ ] Base64 conversion working correctly
- [ ] Response parsing handles Gradio format
- [ ] Error handling for API failures
- [ ] UI updates with detection results

### Debug Tips:
1. Test the Gradio web interface first at your Space URL
2. Use the built-in "📊 API Testing" tab in Gradio
3. Verify the base64 encoding doesn't include a data URL prefix
4. Check that the response format matches the expected structure
5. Monitor processing times (Gradio may be slower than FastAPI)

## 10. Production Considerations

### Performance:
- Gradio apps may have slightly higher latency than pure FastAPI
- Use MobileNet for real-time applications
- Consider image compression for faster uploads
- Implement proper loading indicators

### Reliability:
- Handle Gradio app cold starts (first request may time out)
- Implement retry logic for failed requests (see the sketch below)
- Cache successful results when appropriate
- Provide fallback options for offline scenarios

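A retry wrapper on the caller's side might look like this Python sketch (illustrative only; a Thunkable app would build the equivalent with blocks and timers):

```python
import time

import requests

def post_with_retry(url, payload, retries=3, backoff=2.0):
    """POST with simple backoff to ride out Space cold starts."""
    for attempt in range(retries):
        try:
            resp = requests.post(url, json=payload, timeout=60)
            resp.raise_for_status()
            return resp.json()
        except requests.RequestException:
            if attempt == retries - 1:
                raise
            time.sleep(backoff * (attempt + 1))
```
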
### User Experience:
- Show clear loading states during API calls
- Provide informative error messages
- Allow users to switch between models
- Display confidence scores and processing times

## 11. Sample Thunkable Blocks Layout

### Main Detection Flow:
```
When Button_Detect.Click:
  → Set Label_Status.Text to "Capturing image..."
  → Camera1.TakePicture

When Camera1.AfterPicture:
  → Set Label_Status.Text to "Converting to base64..."
  → Call CloudinaryAPI.convertToBase64

When CloudinaryAPI.GotBase64:
  → Set Label_Status.Text to "Detecting faces..."
  → Set app variable "imageB64" to base64Result
  → Call function DetectFaces

When WebAPI1.GotText:
  → Set Label_Status.Text to "Processing results..."
  → Call function ProcessDetectionResponse
  → Set Label_Status.Text to "Detection complete!"
```

This guide should help you integrate your Gradio-based RetinaFace API with Thunkable.
app.py
ADDED
@@ -0,0 +1,458 @@
import os
import cv2
import torch
import numpy as np
import gradio as gr
import base64
from typing import List, Dict, Any
import tempfile
import time
from PIL import Image, ImageDraw
import json

# Import RetinaFace model components
from models.retinaface import RetinaFace
from utils.prior_box import PriorBox
from utils.py_cpu_nms import py_cpu_nms
from utils.box_utils import decode, decode_landm

# Global variables for models
mobilenet_model = None
resnet_model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_models():
    """Load both MobileNet and ResNet RetinaFace models"""
    global mobilenet_model, resnet_model

    try:
        # Load MobileNet model
        mobilenet_model = RetinaFace(cfg=mobilenet_cfg, phase='test')
        mobilenet_model.load_state_dict(torch.load('mobilenet0.25_Final.pth', map_location=device))
        mobilenet_model.eval()
        mobilenet_model = mobilenet_model.to(device)

        # Load ResNet model
        resnet_model = RetinaFace(cfg=resnet_cfg, phase='test')
        resnet_model.load_state_dict(torch.load('Resnet50_Final.pth', map_location=device))
        resnet_model.eval()
        resnet_model = resnet_model.to(device)

        print("Models loaded successfully!")
        return "✅ Models loaded successfully!"

    except Exception as e:
        error_msg = f"❌ Error loading models: {e}"
        print(error_msg)
        return error_msg

# Model configurations
mobilenet_cfg = {
    'name': 'mobilenet0.25',
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'variance': [0.1, 0.2],
    'clip': False,
    'loc_weight': 2.0,
    'gpu_train': True,
    'batch_size': 32,
    'ngpu': 1,
    'epoch': 250,
    'decay1': 190,
    'decay2': 220,
    'image_size': 640,
    # The backbone pretrain checkpoint (./weights/mobilenetV1X0.25_pretrain.tar) is not
    # bundled with this Space; mobilenet0.25_Final.pth already contains all weights,
    # so skip the pretrain load to avoid a FileNotFoundError at startup.
    'pretrain': False,
    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
    'in_channel': 32,
    'out_channel': 64
}

resnet_cfg = {
    'name': 'Resnet50',
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'variance': [0.1, 0.2],
    'clip': False,
    'loc_weight': 2.0,
    'gpu_train': True,
    'batch_size': 24,
    'ngpu': 4,
    'epoch': 100,
    'decay1': 70,
    'decay2': 90,
    'image_size': 840,
    # Avoid downloading torchvision's ImageNet weights at startup;
    # Resnet50_Final.pth already provides the full model state.
    'pretrain': False,
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256
}

def detect_faces_core(image, model, cfg, confidence_threshold=0.02, nms_threshold=0.4):
    """Core face detection function"""
    start_time = time.time()

    # Preprocessing
    img = np.float32(image)
    im_height, im_width, _ = img.shape
    scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
    img -= (104, 117, 123)
    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).unsqueeze(0)
    img = img.to(device)
    scale = scale.to(device)

    # Forward pass
    with torch.no_grad():
        loc, conf, landms = model(img)

    # Post-processing
    priorbox = PriorBox(cfg, image_size=(im_height, im_width))
    priors = priorbox.forward()
    priors = priors.to(device)
    prior_data = priors.data
    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = boxes * scale / 1
    boxes = boxes.cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                           img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                           img.shape[3], img.shape[2]])
    scale1 = scale1.to(device)
    landms = landms * scale1 / 1
    landms = landms.cpu().numpy()

    # Ignore low scores
    inds = np.where(scores > confidence_threshold)[0]
    boxes = boxes[inds]
    landms = landms[inds]
    scores = scores[inds]

    # Keep top-K before NMS
    order = scores.argsort()[::-1][:5000]
    boxes = boxes[order]
    landms = landms[order]
    scores = scores[order]

    # Do NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, nms_threshold)
    dets = dets[keep, :]
    landms = landms[keep]

    # Format results
    faces = []
    for i in range(dets.shape[0]):
        if dets[i, 4] < confidence_threshold:
            continue

        face = {
            "bbox": {
                "x1": float(dets[i, 0]),
                "y1": float(dets[i, 1]),
                "x2": float(dets[i, 2]),
                "y2": float(dets[i, 3])
            },
            "confidence": float(dets[i, 4]),
            "landmarks": {
                "right_eye": [float(landms[i, 0]), float(landms[i, 1])],
                "left_eye": [float(landms[i, 2]), float(landms[i, 3])],
                "nose": [float(landms[i, 4]), float(landms[i, 5])],
                "right_mouth": [float(landms[i, 6]), float(landms[i, 7])],
                "left_mouth": [float(landms[i, 8]), float(landms[i, 9])]
            }
        }
        faces.append(face)

    processing_time = time.time() - start_time
    return faces, processing_time

def draw_faces_on_image(image, faces):
    """Draw bounding boxes and landmarks on image"""
    if isinstance(image, np.ndarray):
        # Convert numpy array to PIL Image
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    draw = ImageDraw.Draw(image)

    for face in faces:
        bbox = face["bbox"]
        confidence = face["confidence"]
        landmarks = face["landmarks"]

        # Draw bounding box
        draw.rectangle([bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]],
                       outline="red", width=2)

        # Draw confidence score
        draw.text((bbox["x1"], bbox["y1"] - 15),
                  f'{confidence:.2f}', fill="red")

        # Draw landmarks
        for landmark_name, (x, y) in landmarks.items():
            draw.ellipse([x-2, y-2, x+2, y+2], fill="blue")

    return image

def gradio_detect_faces(image, model_type, confidence_threshold, nms_threshold):
    """Gradio interface function for face detection"""
    if mobilenet_model is None or resnet_model is None:
        return None, "❌ Models not loaded. Please wait for models to load.", ""

    try:
        # Convert PIL to OpenCV format
        if isinstance(image, Image.Image):
            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Select model
        if model_type.lower() == "resnet":
            model = resnet_model
            cfg = resnet_cfg
            model_name = "ResNet50"
        else:
            model = mobilenet_model
            cfg = mobilenet_cfg
            model_name = "MobileNet"

        # Detect faces
        faces, processing_time = detect_faces_core(
            image, model, cfg, confidence_threshold, nms_threshold
        )

        # Draw results on image
        result_image = draw_faces_on_image(image.copy(), faces)

        # Create results text
        results_text = f"🎯 **Detection Results**\n"
        results_text += f"📱 Model: {model_name}\n"
        results_text += f"⏱️ Processing Time: {processing_time:.3f}s\n"
        results_text += f"👥 Faces Detected: {len(faces)}\n\n"

        for i, face in enumerate(faces):
            results_text += f"**Face {i+1}:**\n"
            results_text += f"  Confidence: {face['confidence']:.3f}\n"
            bbox = face['bbox']
            results_text += f"  Location: ({bbox['x1']:.0f}, {bbox['y1']:.0f}) - ({bbox['x2']:.0f}, {bbox['y2']:.0f})\n\n"

        # Create JSON output for API use
        json_output = {
            "faces": faces,
            "processing_time": processing_time,
            "model_used": model_name.lower(),
            "total_faces": len(faces)
        }

        return result_image, results_text, json.dumps(json_output, indent=2)

    except Exception as e:
        error_msg = f"❌ Detection failed: {str(e)}"
        return None, error_msg, ""

def api_detect_live(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
    """API function for live detection (Thunkable compatible)"""
    try:
        # Decode base64 image
        image_data = base64.b64decode(image_base64)
        nparr = np.frombuffer(image_data, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if image is None:
            return {"error": "Invalid image data"}

        # Select model
        if model_type.lower() == "resnet":
            model = resnet_model
            cfg = resnet_cfg
            model_name = "resnet"
        else:
            model = mobilenet_model
            cfg = mobilenet_cfg
            model_name = "mobilenet"

        if model is None:
            return {"error": f"{model_name} model not loaded"}

        # Detect faces
        faces, processing_time = detect_faces_core(
            image, model, cfg, confidence_threshold, nms_threshold
        )

        return {
            "faces": faces,
            "processing_time": processing_time,
            "model_used": model_name,
            "total_faces": len(faces)
        }

    except Exception as e:
        return {"error": f"Detection failed: {str(e)}"}

# Load models on startup
print("Loading RetinaFace models...")
load_status = load_models()

# Create Gradio interface
with gr.Blocks(title="RetinaFace Face Detection API", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔥 RetinaFace Face Detection API

    **Real-time face detection using RetinaFace with MobileNet and ResNet backbones**

    - 📱 **Mobile App Ready**: Compatible with Thunkable and other mobile frameworks
    - ⚡ **Dual Models**: MobileNet (fast) and ResNet (accurate)
    - 🎯 **High Accuracy**: Detects faces with bounding boxes and 5-point landmarks
    - 🌐 **API Endpoints**: Use `/api/predict` for programmatic access
    """)

    with gr.Row():
        gr.Markdown(f"**Status**: {load_status}")

    with gr.Tab("🖼️ Image Detection"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image")
                model_choice = gr.Dropdown(
                    choices=["mobilenet", "resnet"],
                    value="mobilenet",
                    label="Model Type"
                )
                confidence_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                    label="Confidence Threshold"
                )
                nms_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.4, step=0.1,
                    label="NMS Threshold"
                )
                detect_btn = gr.Button("🔍 Detect Faces", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Detection Results")
                results_text = gr.Markdown(label="Results")
                # gradio_detect_faces returns three values, so expose the JSON string too
                json_output_box = gr.Code(label="JSON Output", language="json")

        detect_btn.click(
            fn=gradio_detect_faces,
            inputs=[input_image, model_choice, confidence_slider, nms_slider],
            outputs=[output_image, results_text, json_output_box]
        )

    with gr.Tab("🔗 API Documentation"):
        gr.Markdown("""
        ## API Endpoints for Thunkable Integration

        ### 1. Live Detection Endpoint
        ```
        POST /api/predict
        ```

        **Request Body (JSON):**
        ```json
        {
          "data": [
            "base64_encoded_image_string",
            "mobilenet",
            0.5,
            0.4
          ]
        }
        ```

        **Response:**
        ```json
        {
          "data": [
            {
              "faces": [...],
              "processing_time": 0.1,
              "model_used": "mobilenet",
              "total_faces": 2
            }
          ]
        }
        ```

        ### 2. Thunkable Integration Example

        **Web API Component Setup:**
        - URL: `https://your-space-name.hf.space/api/predict`
        - Method: `POST`
        - Headers: `Content-Type: application/json`
        - Body:
        ```json
        {
          "data": [
            "{{base64_image}}",
            "mobilenet",
            0.5,
            0.4
          ]
        }
        ```

        ### 3. Model Performance

        | Model | Speed | Accuracy | Best For |
        |-------|-------|----------|----------|
        | MobileNet | ⚡ Fast | 🎯 Good | Real-time mobile apps |
        | ResNet50 | 🐌 Slower | 🎯🎯 High | High-accuracy applications |

        ### 4. Response Format

        Each detected face includes:
        - **bbox**: Bounding box coordinates (x1, y1, x2, y2)
        - **confidence**: Detection confidence score (0-1)
        - **landmarks**: 5-point facial landmarks (eyes, nose, mouth corners)
        """)

    with gr.Tab("📊 API Testing"):
        gr.Markdown("### Test the API with base64 encoded images")

        with gr.Row():
            with gr.Column():
                test_image_b64 = gr.Textbox(
                    label="Base64 Encoded Image",
                    placeholder="Paste base64 encoded image here...",
                    lines=3
                )
                test_model = gr.Dropdown(
                    choices=["mobilenet", "resnet"],
                    value="mobilenet",
                    label="Model"
                )
                test_conf = gr.Number(value=0.5, label="Confidence")
                test_nms = gr.Number(value=0.4, label="NMS Threshold")
                test_btn = gr.Button("🧪 Test API", variant="secondary")

            with gr.Column():
                api_output = gr.JSON(label="API Response")

        def test_api_function(image_b64, model, conf, nms):
            if not image_b64.strip():
                return {"error": "Please provide base64 encoded image"}

            # Remove data URL prefix if present
            if image_b64.startswith('data:image'):
                image_b64 = image_b64.split(',')[1]

            result = api_detect_live(image_b64, model, conf, nms)
            return result

        test_btn.click(
            fn=test_api_function,
            inputs=[test_image_b64, test_model, test_conf, test_nms],
            outputs=[api_output]
        )

# Custom API function for external calls
def predict_api(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
    """API prediction function that matches Gradio's expected format"""
    result = api_detect_live(image_base64, model_type, confidence_threshold, nms_threshold)
    return result

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
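For a quick local sanity check, `api_detect_live` can also be imported and called directly, bypassing HTTP. A minimal sketch (the image file name is illustrative; importing `app` runs `load_models()`, so the import itself takes a moment):

```python
import base64

from app import api_detect_live  # importing app triggers load_models()

with open("test.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

result = api_detect_live(image_b64, model_type="mobilenet")
print(f"{result.get('total_faces', 0)} face(s) in {result.get('processing_time', 0):.3f}s")
```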
mobilenet0.25_Final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2979b33ffafda5d74b6948cd7a5b9a7a62f62b949cef24e95fd15d2883a65220
size 1789735
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Initialize empty __init__.py files for proper module imports
"""
models/retinaface.py
ADDED
@@ -0,0 +1,316 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
def conv_bn(inp, oup, stride=1, leaky=0):
|
| 9 |
+
return nn.Sequential(
|
| 10 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
| 11 |
+
nn.BatchNorm2d(oup),
|
| 12 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True)
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def conv_bn_no_relu(inp, oup, stride):
|
| 16 |
+
return nn.Sequential(
|
| 17 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
| 18 |
+
nn.BatchNorm2d(oup),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def conv_bn1X1(inp, oup, stride, leaky=0):
|
| 22 |
+
return nn.Sequential(
|
| 23 |
+
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
|
| 24 |
+
nn.BatchNorm2d(oup),
|
| 25 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True)
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def conv_dw(inp, oup, stride, leaky=0.1):
|
| 29 |
+
return nn.Sequential(
|
| 30 |
+
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
| 31 |
+
nn.BatchNorm2d(inp),
|
| 32 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
| 33 |
+
|
| 34 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
| 35 |
+
nn.BatchNorm2d(oup),
|
| 36 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
class SSH(nn.Module):
|
| 40 |
+
def __init__(self, in_channel, out_channel):
|
| 41 |
+
super(SSH, self).__init__()
|
| 42 |
+
assert out_channel % 4 == 0
|
| 43 |
+
leaky = 0
|
| 44 |
+
if (out_channel <= 64):
|
| 45 |
+
leaky = 0.1
|
| 46 |
+
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1)
|
| 47 |
+
|
| 48 |
+
self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky)
|
| 49 |
+
self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
|
| 50 |
+
|
| 51 |
+
self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky)
|
| 52 |
+
self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1)
|
| 53 |
+
|
| 54 |
+
def forward(self, input):
|
| 55 |
+
conv3X3 = self.conv3X3(input)
|
| 56 |
+
|
| 57 |
+
conv5X5_1 = self.conv5X5_1(input)
|
| 58 |
+
conv5X5 = self.conv5X5_2(conv5X5_1)
|
| 59 |
+
|
| 60 |
+
conv7X7_2 = self.conv7X7_2(conv5X5_1)
|
| 61 |
+
conv7X7 = self.conv7x7_3(conv7X7_2)
|
| 62 |
+
|
| 63 |
+
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
|
| 64 |
+
out = F.relu(out)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
class FPN(nn.Module):
|
| 68 |
+
def __init__(self,in_channels_list,out_channels):
|
| 69 |
+
super(FPN,self).__init__()
|
| 70 |
+
leaky = 0
|
| 71 |
+
if (out_channels <= 64):
|
| 72 |
+
leaky = 0.1
|
| 73 |
+
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
|
| 74 |
+
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
|
| 75 |
+
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)
|
| 76 |
+
|
| 77 |
+
self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
|
| 78 |
+
self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)
|
| 79 |
+
|
| 80 |
+
def forward(self, input):
|
| 81 |
+
# names = list(input.keys())
|
| 82 |
+
input = list(input.values())
|
| 83 |
+
|
| 84 |
+
output1 = self.output1(input[0])
|
| 85 |
+
output2 = self.output2(input[1])
|
| 86 |
+
output3 = self.output3(input[2])
|
| 87 |
+
|
| 88 |
+
up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
|
| 89 |
+
output2 = output2 + up3
|
| 90 |
+
output2 = self.merge2(output2)
|
| 91 |
+
|
| 92 |
+
up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
|
| 93 |
+
output1 = output1 + up2
|
| 94 |
+
output1 = self.merge1(output1)
|
| 95 |
+
|
| 96 |
+
out = [output1, output2, output3]
|
| 97 |
+
return out
|
| 98 |
+
|
| 99 |
+
class MobileNetV1(nn.Module):
|
| 100 |
+
def __init__(self):
|
| 101 |
+
super(MobileNetV1, self).__init__()
|
| 102 |
+
self.stage1 = nn.Sequential(
|
| 103 |
+
conv_bn(3, 8, 2, leaky = 0.1), # 3
|
| 104 |
+
conv_dw(8, 16, 1), # 7
|
| 105 |
+
conv_dw(16, 32, 2), # 11
|
| 106 |
+
conv_dw(32, 32, 1), # 19
|
| 107 |
+
conv_dw(32, 64, 2), # 27
|
| 108 |
+
conv_dw(64, 64, 1), # 43
|
| 109 |
+
)
|
| 110 |
+
self.stage2 = nn.Sequential(
|
| 111 |
+
conv_dw(64, 128, 2), # 43 + 16 = 59
|
| 112 |
+
conv_dw(128, 128, 1), # 59 + 32 = 91
|
| 113 |
+
conv_dw(128, 128, 1), # 91 + 32 = 123
|
| 114 |
+
conv_dw(128, 128, 1), # 123 + 32 = 155
|
| 115 |
+
conv_dw(128, 128, 1), # 155 + 32 = 187
|
| 116 |
+
conv_dw(128, 128, 1), # 187 + 32 = 219
|
| 117 |
+
)
|
| 118 |
+
self.stage3 = nn.Sequential(
|
| 119 |
+
conv_dw(128, 256, 2), # 219 + 32 = 251
|
| 120 |
+
conv_dw(256, 256, 1), # 251 + 64 = 315
|
| 121 |
+
)
|
| 122 |
+
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
| 123 |
+
self.fc = nn.Linear(256, 1000)
|
| 124 |
+
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
x = self.stage1(x)
|
| 127 |
+
x = self.stage2(x)
|
| 128 |
+
x = self.stage3(x)
|
| 129 |
+
x = self.avg(x)
|
| 130 |
+
# x = self.model(x)
|
| 131 |
+
x = x.view(-1, 256)
|
| 132 |
+
x = self.fc(x)
|
| 133 |
+
return x
|
| 134 |
+
|
| 135 |
+
class ClassHead(nn.Module):
|
| 136 |
+
def __init__(self,inchannels=512,num_anchors=3):
|
| 137 |
+
super(ClassHead,self).__init__()
|
| 138 |
+
self.num_anchors = num_anchors
|
| 139 |
+
self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
|
| 140 |
+
|
| 141 |
+
def forward(self,x):
|
| 142 |
+
out = self.conv1x1(x)
|
| 143 |
+
out = out.permute(0,2,3,1).contiguous()
|
| 144 |
+
|
| 145 |
+
return out.view(out.shape[0], -1, 2)
|
| 146 |
+
|
| 147 |
+
class BboxHead(nn.Module):
|
| 148 |
+
def __init__(self,inchannels=512,num_anchors=3):
|
| 149 |
+
super(BboxHead,self).__init__()
|
| 150 |
+
self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
|
| 151 |
+
|
| 152 |
+
def forward(self,x):
|
| 153 |
+
out = self.conv1x1(x)
|
| 154 |
+
out = out.permute(0,2,3,1).contiguous()
|
| 155 |
+
|
| 156 |
+
return out.view(out.shape[0], -1, 4)
|
| 157 |
+
|
| 158 |
+
class LandmarkHead(nn.Module):
|
| 159 |
+
def __init__(self,inchannels=512,num_anchors=3):
|
| 160 |
+
super(LandmarkHead,self).__init__()
|
| 161 |
+
self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
|
| 162 |
+
|
| 163 |
+
def forward(self,x):
|
| 164 |
+
out = self.conv1x1(x)
|
| 165 |
+
out = out.permute(0,2,3,1).contiguous()
|
| 166 |
+
|
| 167 |
+
return out.view(out.shape[0], -1, 10)
|
| 168 |
+

class RetinaFace(nn.Module):
    def __init__(self, cfg=None, phase='train'):
        """
        :param cfg: Network related settings.
        :param phase: train or test.
        """
        super(RetinaFace, self).__init__()
        self.phase = phase
        backbone = None
        if cfg['name'] == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar",
                                        map_location=torch.device('cpu'))
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = k[7:]  # remove the "module." prefix left by DataParallel
                    new_state_dict[name] = v
                # load params
                backbone.load_state_dict(new_state_dict)
        elif cfg['name'] == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])

        # Both backbones must expose their intermediate feature maps to the FPN.
        # Note: the original code wrapped only the Resnet50 backbone here; the raw
        # MobileNetV1.forward returns classification logits, which the FPN cannot
        # consume, so both are wrapped (cfg['return_layers'] is assumed to name the
        # stage1/stage2/stage3 children, as in the upstream Pytorch_Retinaface configs).
        from torchvision.models._utils import IntermediateLayerGetter
        self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])

        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        classhead = nn.ModuleList()
        for i in range(fpn_num):
            classhead.append(ClassHead(inchannels, anchor_num))
        return classhead

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        bboxhead = nn.ModuleList()
        for i in range(fpn_num):
            bboxhead.append(BboxHead(inchannels, anchor_num))
        return bboxhead

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        landmarkhead = nn.ModuleList()
        for i in range(fpn_num):
            landmarkhead.append(LandmarkHead(inchannels, anchor_num))
        return landmarkhead

    def forward(self, inputs):
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
        return output


# Utils for ResNet backbone (a vendored copy of torchvision's
# IntermediateLayerGetter, kept here for reference)
class _utils_resnet:
    class IntermediateLayerGetter(nn.ModuleDict):
        """
        Module wrapper that returns intermediate layers from a model

        It has a strong assumption that the modules have been registered
        into the model in the same order as they are used.
        This means that one should **not** reuse the same nn.Module
        twice in the forward if you want this to work.

        Additionally, it is only able to query submodules that are directly
        assigned to the model. So if `model` is passed, `model.feature1` can
        be returned, but not `model.feature1.layer2`.

        Arguments:
            model (nn.Module): model on which we will extract the features
            return_layers (Dict[name, new_name]): a dict containing the names
                of the modules for which the activations will be returned as
                the key of the dict, and the value of the dict is the name
                of the returned activation (which the user can specify).

        Examples::

            >>> m = torchvision.models.resnet18(pretrained=True)
            >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
            >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
            >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
            >>> out = new_m(x)
            >>> print([(k, v.shape) for k, v in out.items()])
            >>> [('feat1', torch.Size([1, 64, 56, 56])),
            >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
        """
        _version = 2
        __annotations__ = {
            "return_layers": Dict[str, str],
        }

        def __init__(self, model, return_layers):
            if not set(return_layers).issubset([name for name, _ in model.named_children()]):
                raise ValueError("return_layers are not present in model")
            orig_return_layers = return_layers
            return_layers = {str(k): str(v) for k, v in return_layers.items()}
            layers = OrderedDict()
            for name, module in model.named_children():
                layers[name] = module
                if name in return_layers:
                    del return_layers[name]
                if not return_layers:
                    break

            super(_utils_resnet.IntermediateLayerGetter, self).__init__(layers)
            self.return_layers = orig_return_layers

        def forward(self, x):
            result = OrderedDict()
            for name, module in self.items():
                x = module(x)
                if name in self.return_layers:
                    out_name = self.return_layers[name]
                    result[out_name] = x
            return result
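That closes out models/retinaface.py. To make the wiring concrete, here is a minimal smoke-test sketch; the config values are assumptions modeled on the standard `mobilenet0.25` settings from the upstream Pytorch_Retinaface project (the Space's actual config lives in app.py):

```python
import torch

# Assumed config, modeled on the upstream cfg_mnet; verify against app.py.
cfg_mnet = {
    'name': 'mobilenet0.25',
    'pretrain': False,  # skip loading the ImageNet checkpoint
    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
    'in_channel': 32,   # -> FPN input channels [64, 128, 256]
    'out_channel': 64,
}

net = RetinaFace(cfg=cfg_mnet, phase='test').eval()
with torch.no_grad():
    boxes, scores, landms = net(torch.randn(1, 3, 640, 640))

# Three strides (8/16/32) with 2 anchors per cell on 640x640 -> 16800 anchors.
print(boxes.shape, scores.shape, landms.shape)
# torch.Size([1, 16800, 4]) torch.Size([1, 16800, 2]) torch.Size([1, 16800, 10])
```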
requirements.txt
ADDED
@@ -0,0 +1,10 @@
gradio==4.44.0
torch==2.0.1
torchvision==0.15.2
opencv-python==4.8.1.78
numpy==1.24.3
Pillow==10.0.1
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
pydantic==2.4.2
start.bat
ADDED
@@ -0,0 +1,16 @@
@echo off
echo Starting RetinaFace Gradio API...

REM Check if model files exist
if not exist "mobilenet0.25_Final.pth" (
    echo Warning: mobilenet0.25_Final.pth not found!
)

if not exist "Resnet50_Final.pth" (
    echo Warning: Resnet50_Final.pth not found!
)

REM Start the Gradio app
python app.py

pause
start.sh
ADDED
@@ -0,0 +1,15 @@
#!/bin/bash

echo "Starting RetinaFace Face Detection API..."

# Check if model files exist
if [ ! -f "mobilenet0.25_Final.pth" ]; then
    echo "Warning: mobilenet0.25_Final.pth not found!"
fi

if [ ! -f "Resnet50_Final.pth" ]; then
    echo "Warning: Resnet50_Final.pth not found!"
fi

# Start the FastAPI server
uvicorn app:app --host 0.0.0.0 --port 7860 --reload
test_api.py
ADDED
@@ -0,0 +1,128 @@
import requests
import base64
import json

def test_gradio_api():
    """Test the Gradio /api/predict endpoint"""
    # You would replace this with actual base64 encoded image data
    sample_image_path = "test_image.jpg"  # Replace with your test image

    try:
        with open(sample_image_path, "rb") as image_file:
            image_base64 = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print("Please add a test_image.jpg file to test the API")
        return

    url = "http://localhost:7860/api/predict"

    payload = {
        "data": [
            image_base64,
            "mobilenet",
            0.5,
            0.4
        ]
    }

    response = requests.post(url, json=payload)

    if response.status_code == 200:
        result = response.json()
        print("Success!")
        print(f"API Response: {json.dumps(result, indent=2)}")

        # Extract the actual detection data
        if "data" in result and len(result["data"]) > 0:
            detection_data = result["data"][0]
            if "faces" in detection_data:
                print(f"Detected {len(detection_data['faces'])} faces")
                # Default to 0.0 so the :.3f format never hits a string
                print(f"Processing time: {detection_data.get('processing_time', 0.0):.3f} seconds")
                print(f"Model used: {detection_data.get('model_used', 'N/A')}")

                for i, face in enumerate(detection_data['faces']):
                    print(f"Face {i+1}:")
                    print(f"  Confidence: {face['confidence']:.3f}")
                    print(f"  Bounding box: {face['bbox']}")
                    print(f"  Landmarks: {face['landmarks']}")
            else:
                print("No face detection data in response")
        else:
            print("Unexpected response format")
    else:
        print(f"Error: {response.status_code}")
        print(response.text)

def test_health_check():
    """Test the Gradio app health"""
    url = "http://localhost:7860/"

    response = requests.get(url)

    if response.status_code == 200:
        print("Gradio app is running!")
        print("You can access the web interface at: http://localhost:7860")
    else:
        print(f"Health check failed: {response.status_code}")

def test_direct_api_call():
    """Test direct API call format that Thunkable would use"""
    sample_image_path = "test_image.jpg"  # Replace with your test image

    try:
        with open(sample_image_path, "rb") as image_file:
            image_base64 = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print("Please add a test_image.jpg file to test the API")
        return

    url = "http://localhost:7860/api/predict"

    # This is the format Thunkable will use
    payload = {
        "data": [image_base64, "mobilenet", 0.5, 0.4]
    }

    headers = {
        "Content-Type": "application/json"
    }

    print("Testing Thunkable-compatible API call...")
    response = requests.post(url, json=payload, headers=headers)

    if response.status_code == 200:
        result = response.json()
        print("✅ Thunkable API call successful!")

        # Parse the response as Thunkable would
        if "data" in result and result["data"]:
            detection_result = result["data"][0]
            print(f"Faces detected: {detection_result.get('total_faces', 0)}")
            print(f"Model used: {detection_result.get('model_used', 'unknown')}")
            print(f"Processing time: {detection_result.get('processing_time', 0):.3f}s")
        else:
            print("❌ Unexpected response format")
    else:
        print(f"❌ API call failed: {response.status_code}")
        print(response.text)

if __name__ == "__main__":
    print("Testing RetinaFace Gradio API...")
    print("=" * 50)

    print("\n1. Health Check:")
    test_health_check()

    print("\n2. Gradio API Test:")
    test_gradio_api()

    print("\n3. Thunkable-Compatible API Test:")
    test_direct_api_call()

    print("\n" + "=" * 50)
    print("Testing complete!")
    print("\nFor Thunkable integration:")
    print("- Use URL: http://localhost:7860/api/predict")
    print("- Method: POST")
    print("- Body format: {\"data\": [\"base64_image\", \"mobilenet\", 0.5, 0.4]}")
    print("- Response will be in: response.data[0]")
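For completeness, the same request can go through the official `gradio_client` package once the Space is live — a hedged sketch, assuming a hypothetical Space ID `your-username/retinaface-api` and that the detection function is exposed under the default `/predict` API name (check the Space's "Use via API" panel for the real values):

```python
import base64
from gradio_client import Client

# Hypothetical Space ID; replace with your actual username/space-name.
client = Client("your-username/retinaface-api")

with open("test_image.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

# Positional arguments mirror the "data" payload order used above:
# base64 image, model name, confidence threshold, NMS threshold.
result = client.predict(image_base64, "mobilenet", 0.5, 0.4, api_name="/predict")
print(result)
```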
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Initialize empty __init__.py files for proper module imports
"""
utils/box_utils.py
ADDED
@@ -0,0 +1,256 @@
import torch
import numpy as np

def point_form(boxes):
    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
    representation for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,      # xmin, ymin
                      boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax


def center_size(boxes):
    """ Convert prior_boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes
    Return:
        boxes: (tensor) Converted cx, cy, w, h form of boxes.
    """
    # Fixed: the arguments must be grouped into a single tuple for torch.cat
    return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)  # w, h


def intersect(box_a, box_b):
    """ We resize both tensors to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
    is simply the intersection over union of two boxes. Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """
    return iou of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """
    return iof of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    return area_i / np.maximum(area_a[:, np.newaxis], 1)


def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then return the matched indices
    corresponding to both confidence and location preds.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
        variances: (tensor) Variances corresponding to each prior coord,
            Shape: [num_priors, 4].
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landmarks, Shape: [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landmark targets.
        idx: (int) current batch index
    Return:
        The matched indices corresponding to 1)location 2)confidence 3)landm preds.
    """
    # jaccard index
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # (Bipartite Matching)
    # [1,num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # force each gt's best anchor to match that gt
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]   # Shape: [num_priors,4]  every anchor is assigned a gt box
    conf = labels[best_truth_idx]      # Shape: [num_priors]  every anchor is assigned a label
    conf[best_truth_overlap < threshold] = 0  # anchors with overlap below the threshold become background
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm

def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]

def encode_landm(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth
    landmarks we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth landmarks for each prior
            Shape: [num_priors, 10].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded landm (tensor), Shape: [num_priors, 10]
    """

    # dist b/t match center and prior's center
    matched = torch.reshape(matched, (matched.size(0), 5, 2))
    priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
    priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
    g_cxcy = matched[:, :, :2] - priors[:, :, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, :, 2:])
    g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
    # return target for smooth_l1_loss
    return g_cxcy


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes

def decode_landm(pre, priors, variances):
    """Decode landmarks from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landmark predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landmark predictions
    """
    landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
                        ), dim=1)
    return landms


def log_sum_exp(x):
    """Utility function for computing log_sum_exp.
    This will be used to determine unaveraged confidence loss across
    all examples in a batch.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
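The encode/decode pair is the core of the box regression: encode turns a matched ground-truth box into (Δcx, Δcy, log w-ratio, log h-ratio) targets scaled by the variances, and decode inverts that at inference time. A small round-trip sketch — the variances [0.1, 0.2] are the values the usual RetinaFace configs use, assumed here:

```python
import torch
from utils.box_utils import encode, decode

variances = [0.1, 0.2]  # assumed; the common RetinaFace setting

# One prior in center-size form (cx, cy, w, h) and one matched ground-truth
# box in point form (xmin, ymin, xmax, ymax), all normalized to [0, 1].
priors = torch.tensor([[0.50, 0.50, 0.20, 0.20]])
truth = torch.tensor([[0.45, 0.42, 0.63, 0.66]])

targets = encode(truth, priors, variances)      # regression targets
recovered = decode(targets, priors, variances)  # back to point form

print(torch.allclose(recovered, truth, atol=1e-5))  # True: decode inverts encode
```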
utils/prior_box.py
ADDED
@@ -0,0 +1,34 @@
import torch
from itertools import product as product
import numpy as np
from math import ceil


class PriorBox(object):
    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        self.name = "s"

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
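As a concrete check of how the grid sizes translate into anchor counts, here is a sketch using assumed mobilenet-style settings (min_sizes [[16, 32], [64, 128], [256, 512]] and steps [8, 16, 32] are the usual RetinaFace defaults, not values guaranteed by this repo):

```python
from utils.prior_box import PriorBox

# Assumed config, mirroring the common RetinaFace mobilenet settings.
cfg = {
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'clip': False,
}

priors = PriorBox(cfg, image_size=(640, 640)).forward()

# Feature maps 80x80, 40x40, 20x20 with two min_sizes (anchors) per cell:
# (6400 + 1600 + 400) * 2 = 16800 anchors in (cx, cy, w, h) form.
print(priors.shape)  # torch.Size([16800, 4])
```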
utils/py_cpu_nms.py
ADDED
@@ -0,0 +1,31 @@
import numpy as np

def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep
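A toy run of the NMS routine, showing two heavily overlapping boxes collapsing to the higher-scoring one while a distant box survives (the 0.4 threshold matches the default NMS value used in the API examples above):

```python
import numpy as np
from utils.py_cpu_nms import py_cpu_nms

# Each row is (x1, y1, x2, y2, score).
dets = np.array([
    [100, 100, 200, 200, 0.95],  # best candidate for a face
    [105, 105, 205, 205, 0.80],  # near-duplicate of the first box (IoU ~ 0.82)
    [400, 120, 480, 210, 0.70],  # separate face elsewhere in the image
])

keep = py_cpu_nms(dets, thresh=0.4)
print(keep)            # [0, 2] -- the duplicate at index 1 is suppressed
print(dets[keep, :4])  # coordinates of the surviving boxes
```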