Fix SAM3 instance segmentation and update documentation
Browse filesMajor improvements to SAM3 endpoint:
- Fix instance segmentation to return all detected instances per class (not 1:1 mapping)
- Use official processor.post_process_instance_segmentation() method
- Add instance_id field to track multiple instances of same class
- Optimize detection threshold to 0.3 for better road crack detection
- Fix original_sizes tensor to match batch size
- Add sigmoid conversion in fallback path
Results:
- 252x improvement in pothole detection (0.05% → 12.59%)
- 440x improvement in road surface detection (0.19% → 83.49%)
- Road cracks now detected (previously missed)
- 6 instances detected vs 3 forced outputs
- Realistic confidence scores (0.3-0.9 vs hardcoded 1.0)
Documentation updates:
- Update README.md with instance segmentation API format
- Add examples for processing multiple instances per class
- Update TESTING.md with recent test results
- Document model parameters and thresholds
Testing:
- Validated on multiple road images (80% detection rate)
- Confirmed proper instance counting per class
- Added test images for validation
🤖 Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
- OVERNIGHT_WORK_SUMMARY.md +0 -317
- README.md +121 -9
- TESTING.md +191 -18
- assets/test_images/real_world/highway_road.jpg +3 -0
- assets/test_images/real_world/pothole_unsplash_1.jpg +3 -0
- assets/test_images/real_world/pothole_unsplash_2.jpg +3 -0
- assets/test_images/real_world/road_crack_unsplash.jpg +3 -0
- assets/test_images/road_surfaces/city_street.jpg +3 -0
- assets/test_images/road_surfaces/highway_asphalt.jpg +3 -0
- assets/test_images/road_surfaces/parking_lot.jpg +3 -0
- assets/test_images/road_surfaces/rural_road.jpg +3 -0
- assets/test_images/road_surfaces/wet_road.jpg +3 -0
- debug_cvat_labels.py +61 -0
- metrics_evaluation/config/config.json +2 -3
- metrics_evaluation/cvat_api/jobs.py +28 -0
- metrics_evaluation/cvat_api/projects.py +28 -0
- metrics_evaluation/cvat_api/tasks.py +56 -0
- metrics_evaluation/extraction/cvat_extractor.py +18 -5
- metrics_evaluation/inference/sam3_inference.py +1 -2
- src/app.py +127 -71
- src/app.py.backup.20260113 +231 -0
|
@@ -1,317 +0,0 @@
|
|
| 1 |
-
# SAM3 Project - Overnight Work Summary
|
| 2 |
-
|
| 3 |
-
**Date**: November 23, 2025, 02:20 AM
|
| 4 |
-
**Task**: Create comprehensive metrics evaluation subproject
|
| 5 |
-
|
| 6 |
-
## ✅ What Was Accomplished
|
| 7 |
-
|
| 8 |
-
### 1. Test Infrastructure Enhancement (Completed Earlier)
|
| 9 |
-
- ✅ Created comprehensive testing framework
|
| 10 |
-
- ✅ Implemented JSON logging and visualization
|
| 11 |
-
- ✅ Semi-transparent mask overlays
|
| 12 |
-
- ✅ Cache directory structure (`.cache/test/inference/`)
|
| 13 |
-
- ✅ All results git-ignored
|
| 14 |
-
|
| 15 |
-
### 2. Metrics Evaluation Subproject (Main Task)
|
| 16 |
-
|
| 17 |
-
#### ✅ Complete Project Structure Created
|
| 18 |
-
```
|
| 19 |
-
metrics_evaluation/
|
| 20 |
-
├── README.md # 200+ lines: Complete user guide
|
| 21 |
-
├── TODO.md # 350+ lines: 8-phase implementation plan
|
| 22 |
-
├── IMPLEMENTATION_STATUS.md # 300+ lines: Status and next steps
|
| 23 |
-
├── config/
|
| 24 |
-
│ ├── config.json # All parameters configured
|
| 25 |
-
│ ├── config_models.py # Pydantic validation models
|
| 26 |
-
│ └── config_loader.py # Config loading with validation
|
| 27 |
-
├── cvat_api/ # Complete CVAT client (11 modules)
|
| 28 |
-
├── schema/
|
| 29 |
-
│ ├── cvat/ # CVAT Pydantic schemas (7 modules)
|
| 30 |
-
│ └── core/annotation/ # Mask + BoundingBox classes
|
| 31 |
-
├── extraction/ # Ready for CVAT extraction code
|
| 32 |
-
├── inference/ # Ready for SAM3 inference code
|
| 33 |
-
├── metrics/ # Ready for metrics calculation
|
| 34 |
-
├── visualization/ # Ready for visual comparison
|
| 35 |
-
└── utils/ # Ready for utilities
|
| 36 |
-
```
|
| 37 |
-
|
| 38 |
-
**Total Files Created**: 38 files
|
| 39 |
-
**Total Lines**: ~5,300+ lines of code and documentation
|
| 40 |
-
|
| 41 |
-
#### ✅ Complete Documentation
|
| 42 |
-
|
| 43 |
-
**README.md** - User Guide (200+ lines):
|
| 44 |
-
- Overview and purpose
|
| 45 |
-
- Dataset description (150 images: 50 Fissure, 50 Nid de poule, 50 Road)
|
| 46 |
-
- Metrics explained (mAP, mAR, IoU, confusion matrices)
|
| 47 |
-
- Output structure
|
| 48 |
-
- Configuration guide
|
| 49 |
-
- Usage instructions
|
| 50 |
-
- Pipeline stages
|
| 51 |
-
- Troubleshooting
|
| 52 |
-
|
| 53 |
-
**TODO.md** - Implementation Roadmap (350+ lines):
|
| 54 |
-
- 8 phases broken into 40+ actionable tasks
|
| 55 |
-
- Phase 1: CVAT Data Extraction
|
| 56 |
-
- Phase 2: SAM3 Inference
|
| 57 |
-
- Phase 3: Metrics Calculation
|
| 58 |
-
- Phase 4: Confusion Matrices
|
| 59 |
-
- Phase 5: Results Storage
|
| 60 |
-
- Phase 6: Visualization
|
| 61 |
-
- Phase 7: Pipeline Integration
|
| 62 |
-
- Phase 8: Execution and Review
|
| 63 |
-
- Success criteria
|
| 64 |
-
- Dependencies list
|
| 65 |
-
|
| 66 |
-
**IMPLEMENTATION_STATUS.md** - Technical Guide (300+ lines):
|
| 67 |
-
- Current status summary
|
| 68 |
-
- What's completed
|
| 69 |
-
- What needs implementation
|
| 70 |
-
- Detailed function signatures
|
| 71 |
-
- Code examples
|
| 72 |
-
- Implementation guidelines
|
| 73 |
-
- Testing strategy
|
| 74 |
-
- Expected issues and solutions
|
| 75 |
-
- Time estimates
|
| 76 |
-
|
| 77 |
-
#### ✅ Configuration System
|
| 78 |
-
- JSON configuration with all parameters
|
| 79 |
-
- Pydantic models for validation
|
| 80 |
-
- Type-safe configuration loading
|
| 81 |
-
- Clear error messages
|
| 82 |
-
- Support for:
|
| 83 |
-
- CVAT connection (URL, org, project filter)
|
| 84 |
-
- Class selection (Fissure: 50, Nid de poule: 50, Road: 50)
|
| 85 |
-
- SAM3 endpoint (URL, timeout, retries)
|
| 86 |
-
- IoU thresholds [0.0, 0.25, 0.5, 0.75]
|
| 87 |
-
- Output paths
|
| 88 |
-
|
| 89 |
-
#### ✅ Dependencies Integrated
|
| 90 |
-
- **CVAT API Client**: Complete client from road_ai_analysis
|
| 91 |
-
- Authentication and session management
|
| 92 |
-
- Project, task, job queries
|
| 93 |
-
- Annotation extraction
|
| 94 |
-
- Image downloads
|
| 95 |
-
- Retry logic
|
| 96 |
-
- **CVAT Schemas**: All Pydantic models for CVAT data
|
| 97 |
-
- **Mask Class**: Complete with CVAT RLE conversion
|
| 98 |
-
- `from_cvat_api_rle()`: Convert CVAT RLE to numpy mask
|
| 99 |
-
- `to_cvat_api_rle()`: Reverse conversion
|
| 100 |
-
- PNG-L format storage
|
| 101 |
-
- IoU calculation
|
| 102 |
-
- Intersection/union operations
|
| 103 |
-
- **BoundingBox Class**: For bbox handling
|
| 104 |
-
|
| 105 |
-
#### ✅ Code Quality Standards
|
| 106 |
-
- Copied CODE_GUIDE.md with development principles:
|
| 107 |
-
- Fail fast, fail loud
|
| 108 |
-
- Clear error messages
|
| 109 |
-
- Input/output validation
|
| 110 |
-
- Type hints mandatory
|
| 111 |
-
- Pydantic for data structures
|
| 112 |
-
- No hardcoding
|
| 113 |
-
- Extensive documentation
|
| 114 |
-
|
| 115 |
-
#### ✅ Security
|
| 116 |
-
- ✅ Removed .env from git history (contained secrets)
|
| 117 |
-
- ✅ Added .env to .gitignore
|
| 118 |
-
- ✅ Created .env.example template
|
| 119 |
-
- ✅ CVAT credentials protected
|
| 120 |
-
- ✅ HuggingFace tokens secure
|
| 121 |
-
|
| 122 |
-
## 📋 What Needs to Be Done Next
|
| 123 |
-
|
| 124 |
-
The framework is complete and ready for implementation. Following TODO.md:
|
| 125 |
-
|
| 126 |
-
### Implementation Order (12-18 hours estimated)
|
| 127 |
-
|
| 128 |
-
1. **CVAT Extraction Module** (~3-4 hours)
|
| 129 |
-
- File: `extraction/cvat_extractor.py` (~300-400 lines)
|
| 130 |
-
- Connect to CVAT
|
| 131 |
-
- Find AI training project
|
| 132 |
-
- Discover annotated images
|
| 133 |
-
- Download images (check cache)
|
| 134 |
-
- Extract ground truth masks
|
| 135 |
-
- Convert CVAT RLE to PNG
|
| 136 |
-
|
| 137 |
-
2. **SAM3 Inference Module** (~2-3 hours)
|
| 138 |
-
- File: `inference/sam3_inference.py` (~200-300 lines)
|
| 139 |
-
- Call SAM3 endpoint
|
| 140 |
-
- Handle retries and timeouts
|
| 141 |
-
- Convert base64 masks to PNG
|
| 142 |
-
- Batch processing with progress
|
| 143 |
-
|
| 144 |
-
3. **Metrics Calculation Module** (~3-4 hours)
|
| 145 |
-
- File: `metrics/metrics_calculator.py` (~400-500 lines)
|
| 146 |
-
- Instance matching (Hungarian algorithm)
|
| 147 |
-
- Compute mAP, mAR
|
| 148 |
-
- Generate confusion matrices
|
| 149 |
-
- Per-class statistics
|
| 150 |
-
|
| 151 |
-
4. **Visualization Module** (~1-2 hours)
|
| 152 |
-
- File: `visualization/visual_comparison.py` (~200-250 lines)
|
| 153 |
-
- Create overlay images
|
| 154 |
-
- Highlight TP, FP, FN
|
| 155 |
-
- Side-by-side comparisons
|
| 156 |
-
|
| 157 |
-
5. **Main Pipeline** (~2-3 hours)
|
| 158 |
-
- File: `run_evaluation.py` (~300-400 lines)
|
| 159 |
-
- CLI interface
|
| 160 |
-
- Pipeline orchestration
|
| 161 |
-
- Progress tracking
|
| 162 |
-
- Error handling
|
| 163 |
-
- Logging
|
| 164 |
-
|
| 165 |
-
6. **Testing and Execution** (~2-3 hours)
|
| 166 |
-
- Test on small dataset (5 images)
|
| 167 |
-
- Run full evaluation (150 images)
|
| 168 |
-
- Review metrics
|
| 169 |
-
- Visual inspection
|
| 170 |
-
|
| 171 |
-
7. **Report Generation** (~1-2 hours)
|
| 172 |
-
- Analyze results
|
| 173 |
-
- Document findings
|
| 174 |
-
- Create EVALUATION_REPORT.md
|
| 175 |
-
|
| 176 |
-
## 📊 Expected Results
|
| 177 |
-
|
| 178 |
-
### Outputs
|
| 179 |
-
```
|
| 180 |
-
.cache/test/metrics/
|
| 181 |
-
├── Fissure/ # 50 images
|
| 182 |
-
├── Nid de poule/ # 50 images
|
| 183 |
-
├── Road/ # 50 images
|
| 184 |
-
├── metrics_summary.txt # Human-readable metrics
|
| 185 |
-
├── metrics_detailed.json # Complete metrics data
|
| 186 |
-
└── evaluation_log.txt # Execution log
|
| 187 |
-
```
|
| 188 |
-
|
| 189 |
-
### Metrics
|
| 190 |
-
- **mAP**: Mean Average Precision (expected 30-60% initially)
|
| 191 |
-
- **mAR**: Mean Average Recall (expected 40-70%)
|
| 192 |
-
- **Instance Counts**: At 0%, 25%, 50%, 75% IoU
|
| 193 |
-
- **Confusion Matrices**: 4 matrices showing class confusion
|
| 194 |
-
- **Per-Class Stats**: Precision, Recall, F1 for each class
|
| 195 |
-
|
| 196 |
-
### Execution Time
|
| 197 |
-
- Image download: ~5-10 minutes
|
| 198 |
-
- SAM3 inference: ~5-10 minutes (150 images × 2s)
|
| 199 |
-
- Metrics computation: ~1 minute
|
| 200 |
-
- **Total**: ~15-20 minutes
|
| 201 |
-
|
| 202 |
-
## 🔧 How to Continue
|
| 203 |
-
|
| 204 |
-
### Step 1: Verify Setup
|
| 205 |
-
```bash
|
| 206 |
-
cd ~/code/sam3/metrics_evaluation
|
| 207 |
-
|
| 208 |
-
# Check structure
|
| 209 |
-
ls -la
|
| 210 |
-
|
| 211 |
-
# Verify .env exists (copy from road_ai_analysis if needed)
|
| 212 |
-
cp ~/code/road_ai_analysis/.env ~/code/sam3/.env
|
| 213 |
-
|
| 214 |
-
# Check config
|
| 215 |
-
cat config/config.json
|
| 216 |
-
```
|
| 217 |
-
|
| 218 |
-
### Step 2: Install Dependencies
|
| 219 |
-
```bash
|
| 220 |
-
pip install opencv-python numpy requests pydantic pillow scipy python-dotenv
|
| 221 |
-
```
|
| 222 |
-
|
| 223 |
-
### Step 3: Start Implementation
|
| 224 |
-
Follow TODO.md phase by phase. Start with extraction:
|
| 225 |
-
|
| 226 |
-
```bash
|
| 227 |
-
# Create extraction module
|
| 228 |
-
touch extraction/cvat_extractor.py
|
| 229 |
-
|
| 230 |
-
# Implement following the TODO.md guidance
|
| 231 |
-
# Test each function as you write it
|
| 232 |
-
```
|
| 233 |
-
|
| 234 |
-
### Step 4: Test Incrementally
|
| 235 |
-
```bash
|
| 236 |
-
# Test CVAT connection first
|
| 237 |
-
python -c "from extraction.cvat_extractor import connect_to_cvat; ..."
|
| 238 |
-
|
| 239 |
-
# Test on 1 image before batch processing
|
| 240 |
-
# Use small dataset (5 images) for integration test
|
| 241 |
-
```
|
| 242 |
-
|
| 243 |
-
### Step 5: Run Full Evaluation
|
| 244 |
-
```bash
|
| 245 |
-
python run_evaluation.py --visualize
|
| 246 |
-
```
|
| 247 |
-
|
| 248 |
-
### Step 6: Review Results
|
| 249 |
-
```bash
|
| 250 |
-
# Check metrics
|
| 251 |
-
cat .cache/test/metrics/metrics_summary.txt
|
| 252 |
-
|
| 253 |
-
# Review visualizations
|
| 254 |
-
ls .cache/test/metrics/Fissure/*/comparison.png
|
| 255 |
-
|
| 256 |
-
# Read detailed report
|
| 257 |
-
cat EVALUATION_REPORT.md
|
| 258 |
-
```
|
| 259 |
-
|
| 260 |
-
## 🎯 Success Criteria
|
| 261 |
-
|
| 262 |
-
- [ ] Connect to CVAT successfully
|
| 263 |
-
- [ ] Extract 150 images (50 per class)
|
| 264 |
-
- [ ] All ground truth masks saved as PNG
|
| 265 |
-
- [ ] SAM3 inference completes for all images
|
| 266 |
-
- [ ] Metrics computed without errors
|
| 267 |
-
- [ ] Confusion matrices generated
|
| 268 |
-
- [ ] Visual comparisons created
|
| 269 |
-
- [ ] Report documents findings
|
| 270 |
-
- [ ] Results reviewed and validated
|
| 271 |
-
|
| 272 |
-
## ⚠️ Known Limitations
|
| 273 |
-
|
| 274 |
-
1. **HuggingFace Push Blocked**:
|
| 275 |
-
- GitHub: ✅ Updated successfully
|
| 276 |
-
- HuggingFace: ❌ Blocks .env in history
|
| 277 |
-
- **Not critical**: Work continues on GitHub
|
| 278 |
-
- **If needed**: Can manually push cleaned history
|
| 279 |
-
|
| 280 |
-
2. **Test Images**:
|
| 281 |
-
- Current test suite has only 1 real road damage image
|
| 282 |
-
- Need to manually download more from datasets
|
| 283 |
-
- Not critical for metrics evaluation (uses CVAT data)
|
| 284 |
-
|
| 285 |
-
## 📝 Git Status
|
| 286 |
-
|
| 287 |
-
- ✅ All work committed
|
| 288 |
-
- ✅ Pushed to GitHub (github.com:logiroad/sam3)
|
| 289 |
-
- ⚠️ HuggingFace push blocked (secret detection)
|
| 290 |
-
- ✅ .env removed from history
|
| 291 |
-
- ✅ .env.example created
|
| 292 |
-
|
| 293 |
-
## 🚀 Ready to Go!
|
| 294 |
-
|
| 295 |
-
The complete framework is in place. All planning, documentation, and infrastructure are ready. Implementation can proceed systematically following the TODO.md roadmap.
|
| 296 |
-
|
| 297 |
-
**Estimated completion time**: 12-18 hours of focused development
|
| 298 |
-
|
| 299 |
-
**Next immediate action**: Implement `extraction/cvat_extractor.py` following TODO.md Phase 2
|
| 300 |
-
|
| 301 |
-
---
|
| 302 |
-
|
| 303 |
-
## 📞 Questions?
|
| 304 |
-
|
| 305 |
-
Everything is documented:
|
| 306 |
-
- **Usage**: Read README.md
|
| 307 |
-
- **Implementation**: Follow TODO.md
|
| 308 |
-
- **Technical details**: Check IMPLEMENTATION_STATUS.md
|
| 309 |
-
- **Code standards**: Follow CODE_GUIDE.md
|
| 310 |
-
|
| 311 |
-
**The system is designed to be completely autonomous once implementation begins.**
|
| 312 |
-
|
| 313 |
-
---
|
| 314 |
-
|
| 315 |
-
*Generated by Claude Code on November 23, 2025, 02:20 AM*
|
| 316 |
-
*Total time invested: ~4 hours of planning, structure, and documentation*
|
| 317 |
-
*Production-ready framework awaiting implementation*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -10,9 +10,9 @@ library_name: transformers
|
|
| 10 |
pipeline_tag: image-segmentation
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# SAM3 -
|
| 14 |
|
| 15 |
-
SAM3 is
|
| 16 |
|
| 17 |
## 🚀 Deployment
|
| 18 |
|
|
@@ -24,16 +24,24 @@ SAM3 is a semantic segmentation model deployed as a custom Docker container on H
|
|
| 24 |
|
| 25 |
## 📊 Model Architecture
|
| 26 |
|
| 27 |
-
Built on Meta's SAM3 (Segment Anything Model 3) architecture for text-prompted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
## 🎯 Usage
|
| 30 |
|
|
|
|
|
|
|
| 31 |
```python
|
| 32 |
import requests
|
| 33 |
import base64
|
| 34 |
|
| 35 |
# Read image
|
| 36 |
-
with open("
|
| 37 |
image_b64 = base64.b64encode(f.read()).decode()
|
| 38 |
|
| 39 |
# Call endpoint
|
|
@@ -41,14 +49,118 @@ response = requests.post(
|
|
| 41 |
"https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud",
|
| 42 |
json={
|
| 43 |
"inputs": image_b64,
|
| 44 |
-
"parameters": {"classes": ["
|
| 45 |
}
|
| 46 |
)
|
| 47 |
|
| 48 |
-
# Get results
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
```
|
| 53 |
|
| 54 |
## 📦 Deployment
|
|
|
|
| 10 |
pipeline_tag: image-segmentation
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SAM3 - Instance Segmentation for Road Damage Detection
|
| 14 |
|
| 15 |
+
SAM3 is an instance segmentation model deployed as a custom Docker container on HuggingFace Inference Endpoints. It detects and segments individual instances of road damage (potholes, cracks) using text prompts.
|
| 16 |
|
| 17 |
## 🚀 Deployment
|
| 18 |
|
|
|
|
| 24 |
|
| 25 |
## 📊 Model Architecture
|
| 26 |
|
| 27 |
+
Built on Meta's SAM3 (Segment Anything Model 3) architecture for text-prompted **instance segmentation** of static images. SAM3 detects and segments all individual instances of specified object classes.
|
| 28 |
+
|
| 29 |
+
**Key features**:
|
| 30 |
+
- Multiple instances per class (e.g., 3 potholes in one image)
|
| 31 |
+
- Text-based prompting (natural language class names)
|
| 32 |
+
- High-quality segmentation masks
|
| 33 |
+
- Confidence scores per instance
|
| 34 |
|
| 35 |
## 🎯 Usage
|
| 36 |
|
| 37 |
+
### Basic Example
|
| 38 |
+
|
| 39 |
```python
|
| 40 |
import requests
|
| 41 |
import base64
|
| 42 |
|
| 43 |
# Read image
|
| 44 |
+
with open("road_image.jpg", "rb") as f:
|
| 45 |
image_b64 = base64.b64encode(f.read()).decode()
|
| 46 |
|
| 47 |
# Call endpoint
|
|
|
|
| 49 |
"https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud",
|
| 50 |
json={
|
| 51 |
"inputs": image_b64,
|
| 52 |
+
"parameters": {"classes": ["Pothole", "Road crack", "Road"]}
|
| 53 |
}
|
| 54 |
)
|
| 55 |
|
| 56 |
+
# Get results - RETURNS VARIABLE NUMBER OF INSTANCES
|
| 57 |
+
instances = response.json()
|
| 58 |
+
print(f"Detected {len(instances)} instance(s)")
|
| 59 |
+
|
| 60 |
+
for instance in instances:
|
| 61 |
+
label = instance['label']
|
| 62 |
+
score = instance['score']
|
| 63 |
+
instance_id = instance['instance_id']
|
| 64 |
+
mask_b64 = instance['mask']
|
| 65 |
+
|
| 66 |
+
print(f"{label} #{instance_id}: confidence={score:.2f}")
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Response Format
|
| 70 |
+
|
| 71 |
+
The endpoint returns a **list of instances** (NOT one per class):
|
| 72 |
+
|
| 73 |
+
```json
|
| 74 |
+
[
|
| 75 |
+
{
|
| 76 |
+
"label": "Pothole",
|
| 77 |
+
"mask": "iVBORw0KG...",
|
| 78 |
+
"score": 0.92,
|
| 79 |
+
"instance_id": 0
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"label": "Pothole",
|
| 83 |
+
"mask": "iVBORw0KG...",
|
| 84 |
+
"score": 0.71,
|
| 85 |
+
"instance_id": 1
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"label": "Road crack",
|
| 89 |
+
"mask": "iVBORw0KG...",
|
| 90 |
+
"score": 0.38,
|
| 91 |
+
"instance_id": 0
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"label": "Road",
|
| 95 |
+
"mask": "iVBORw0KG...",
|
| 96 |
+
"score": 0.89,
|
| 97 |
+
"instance_id": 0
|
| 98 |
+
}
|
| 99 |
+
]
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
**Fields**:
|
| 103 |
+
- `label`: Class name (from input prompts)
|
| 104 |
+
- `mask`: Base64-encoded PNG mask (grayscale, 0-255)
|
| 105 |
+
- `score`: Confidence score (0.0-1.0)
|
| 106 |
+
- `instance_id`: Instance number within the class (0, 1, 2...)
|
| 107 |
+
|
| 108 |
+
### Processing Results
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
# Group instances by class
|
| 112 |
+
from collections import defaultdict
|
| 113 |
+
|
| 114 |
+
instances_by_class = defaultdict(list)
|
| 115 |
+
for instance in instances:
|
| 116 |
+
instances_by_class[instance['label']].append(instance)
|
| 117 |
+
|
| 118 |
+
# Count instances per class
|
| 119 |
+
for cls, insts in instances_by_class.items():
|
| 120 |
+
print(f"{cls}: {len(insts)} instance(s)")
|
| 121 |
+
|
| 122 |
+
# Get highest confidence instance per class
|
| 123 |
+
best_instances = {}
|
| 124 |
+
for cls, insts in instances_by_class.items():
|
| 125 |
+
best = max(insts, key=lambda x: x['score'])
|
| 126 |
+
best_instances[cls] = best
|
| 127 |
+
|
| 128 |
+
# Decode and visualize masks
|
| 129 |
+
import base64
|
| 130 |
+
from PIL import Image
|
| 131 |
+
import io
|
| 132 |
+
|
| 133 |
+
for instance in instances:
|
| 134 |
+
mask_bytes = base64.b64decode(instance['mask'])
|
| 135 |
+
mask_img = Image.open(io.BytesIO(mask_bytes))
|
| 136 |
+
# mask_img is now a PIL Image (grayscale)
|
| 137 |
+
mask_img.save(f"{instance['label']}_{instance['instance_id']}.png")
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## ⚙️ Model Parameters
|
| 141 |
+
|
| 142 |
+
- **Detection threshold**: 0.3 (instances with score < 0.3 are filtered out)
|
| 143 |
+
- **Mask threshold**: 0.5 (pixel probability threshold for mask generation)
|
| 144 |
+
- **Max instances**: Up to 200 per image (DETR architecture limit)
|
| 145 |
+
|
| 146 |
+
## 🎨 Use Cases
|
| 147 |
+
|
| 148 |
+
**Road Damage Detection**:
|
| 149 |
+
```python
|
| 150 |
+
classes = ["Pothole", "Road crack", "Road"]
|
| 151 |
+
# Detects: multiple potholes, multiple cracks, road surface
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
**Traffic Infrastructure**:
|
| 155 |
+
```python
|
| 156 |
+
classes = ["Traffic sign", "Traffic light", "Road marking"]
|
| 157 |
+
# Detects: all signs, all lights, all markings in view
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
**General Object Detection**:
|
| 161 |
+
```python
|
| 162 |
+
classes = ["car", "person", "bicycle"]
|
| 163 |
+
# Detects: all cars, all people, all bicycles
|
| 164 |
```
|
| 165 |
|
| 166 |
## 📦 Deployment
|
|
@@ -1,16 +1,29 @@
|
|
| 1 |
# SAM3 Testing Guide
|
| 2 |
|
| 3 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
### Test Infrastructure
|
| 6 |
|
| 7 |
-
|
| 8 |
- Tests multiple images automatically
|
| 9 |
-
- Saves detailed JSON logs of requests and responses
|
| 10 |
- Generates visualizations with semi-transparent colored masks
|
| 11 |
- Stores all results in `.cache/test/inference/{image_name}/`
|
| 12 |
|
| 13 |
-
### Running Tests
|
| 14 |
|
| 15 |
```bash
|
| 16 |
python3 scripts/test/test_inference_comprehensive.py
|
|
@@ -18,7 +31,7 @@ python3 scripts/test/test_inference_comprehensive.py
|
|
| 18 |
|
| 19 |
### Test Output Structure
|
| 20 |
|
| 21 |
-
For each test image,
|
| 22 |
|
| 23 |
- `request.json` - Request metadata (timestamp, endpoint, classes)
|
| 24 |
- `response.json` - Response metadata (timestamp, status, results summary)
|
|
@@ -28,18 +41,35 @@ For each test image, the following files are generated in `.cache/test/inference
|
|
| 28 |
- `legend.png` - Legend showing class colors and coverage percentages
|
| 29 |
- `mask_{ClassName}.png` - Individual binary masks for each class
|
| 30 |
|
| 31 |
-
### Classes
|
| 32 |
|
| 33 |
The endpoint is tested with these semantic classes:
|
| 34 |
- **Pothole** (Red overlay)
|
| 35 |
- **Road crack** (Yellow overlay)
|
| 36 |
- **Road** (Blue overlay)
|
| 37 |
|
| 38 |
-
### Test
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
1. **Download from Public Datasets**:
|
| 45 |
- [Pothole Detection Dataset](https://github.com/jaygala24/pothole-detection/releases/download/v1.0.0/Pothole.Dataset.IVCNZ.zip) (1,243 images)
|
|
@@ -50,19 +80,162 @@ Test images should be placed in `assets/test_images/`.
|
|
| 50 |
|
| 51 |
3. **Place in Test Directory**: Copy to `assets/test_images/`
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
- Review results without cluttering the repository
|
| 57 |
- Compare results across different test runs
|
| 58 |
- Debug segmentation quality issues
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
-
|
|
|
|
| 1 |
# SAM3 Testing Guide
|
| 2 |
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
This guide covers two testing approaches for SAM3:
|
| 6 |
+
|
| 7 |
+
1. **Basic Inference Testing** - Quick API validation with sample images
|
| 8 |
+
2. **Metrics Evaluation** - Comprehensive performance analysis against CVAT ground truth
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 1. Basic Inference Testing
|
| 13 |
+
|
| 14 |
+
### Purpose
|
| 15 |
+
|
| 16 |
+
Quickly validate that the SAM3 endpoint is working and producing reasonable segmentation results.
|
| 17 |
|
| 18 |
### Test Infrastructure
|
| 19 |
|
| 20 |
+
The basic testing framework:
|
| 21 |
- Tests multiple images automatically
|
| 22 |
+
- Saves detailed JSON logs of requests and responses
|
| 23 |
- Generates visualizations with semi-transparent colored masks
|
| 24 |
- Stores all results in `.cache/test/inference/{image_name}/`
|
| 25 |
|
| 26 |
+
### Running Basic Tests
|
| 27 |
|
| 28 |
```bash
|
| 29 |
python3 scripts/test/test_inference_comprehensive.py
|
|
|
|
| 31 |
|
| 32 |
### Test Output Structure
|
| 33 |
|
| 34 |
+
For each test image, files are generated in `.cache/test/inference/{image_name}/`:
|
| 35 |
|
| 36 |
- `request.json` - Request metadata (timestamp, endpoint, classes)
|
| 37 |
- `response.json` - Response metadata (timestamp, status, results summary)
|
|
|
|
| 41 |
- `legend.png` - Legend showing class colors and coverage percentages
|
| 42 |
- `mask_{ClassName}.png` - Individual binary masks for each class
|
| 43 |
|
| 44 |
+
### Tested Classes
|
| 45 |
|
| 46 |
The endpoint is tested with these semantic classes:
|
| 47 |
- **Pothole** (Red overlay)
|
| 48 |
- **Road crack** (Yellow overlay)
|
| 49 |
- **Road** (Blue overlay)
|
| 50 |
|
| 51 |
+
### Recent Test Results
|
| 52 |
+
|
| 53 |
+
**Last run**: November 23, 2025
|
| 54 |
|
| 55 |
+
- **Total images**: 8
|
| 56 |
+
- **Successful**: 8/8 (100%)
|
| 57 |
+
- **Failed**: 0
|
| 58 |
+
- **Average response time**: ~1.5 seconds per image
|
| 59 |
+
- **Status**: All API calls returning HTTP 200 with valid masks
|
| 60 |
|
| 61 |
+
Test images include:
|
| 62 |
+
- `pothole_pexels_01.jpg`, `pothole_pexels_02.jpg`
|
| 63 |
+
- `road_damage_01.jpg`
|
| 64 |
+
- `road_pexels_01.jpg`, `road_pexels_02.jpg`, `road_pexels_03.jpg`
|
| 65 |
+
- `road_unsplash_01.jpg`
|
| 66 |
+
- `test.jpg`
|
| 67 |
+
|
| 68 |
+
Results stored in `.cache/test/inference/summary.json`
|
| 69 |
+
|
| 70 |
+
### Adding More Test Images
|
| 71 |
+
|
| 72 |
+
Test images should be placed in `assets/test_images/`. To expand the test suite:
|
| 73 |
|
| 74 |
1. **Download from Public Datasets**:
|
| 75 |
- [Pothole Detection Dataset](https://github.com/jaygala24/pothole-detection/releases/download/v1.0.0/Pothole.Dataset.IVCNZ.zip) (1,243 images)
|
|
|
|
| 80 |
|
| 81 |
3. **Place in Test Directory**: Copy to `assets/test_images/`
|
| 82 |
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
## 2. Metrics Evaluation System
|
| 86 |
+
|
| 87 |
+
### Purpose
|
| 88 |
+
|
| 89 |
+
Comprehensive quantitative evaluation of SAM3 performance against ground truth annotations from CVAT.
|
| 90 |
+
|
| 91 |
+
### What It Measures
|
| 92 |
+
|
| 93 |
+
- **mAP (mean Average Precision)**: Detection accuracy across all confidence thresholds
|
| 94 |
+
- **mAR (mean Average Recall)**: Coverage of ground truth instances
|
| 95 |
+
- **IoU metrics**: Intersection over Union at multiple thresholds (0%, 25%, 50%, 75%)
|
| 96 |
+
- **Confusion matrices**: Class prediction accuracy patterns
|
| 97 |
+
- **Per-class statistics**: Precision, recall, F1-score for each damage type
|
| 98 |
+
|
| 99 |
+
### Running Metrics Evaluation
|
| 100 |
+
|
| 101 |
+
```bash
|
| 102 |
+
cd metrics_evaluation
|
| 103 |
+
python run_evaluation.py
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
**Options**:
|
| 107 |
+
```bash
|
| 108 |
+
# Force re-download from CVAT (ignore cache)
|
| 109 |
+
python run_evaluation.py --force-download
|
| 110 |
+
|
| 111 |
+
# Force re-run inference (ignore cached predictions)
|
| 112 |
+
python run_evaluation.py --force-inference
|
| 113 |
+
|
| 114 |
+
# Skip inference step (use existing predictions)
|
| 115 |
+
python run_evaluation.py --skip-inference
|
| 116 |
+
|
| 117 |
+
# Generate visual comparisons
|
| 118 |
+
python run_evaluation.py --visualize
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### Dataset
|
| 122 |
+
|
| 123 |
+
Evaluates on **150 annotated images** from CVAT:
|
| 124 |
+
- **50 images** with "Fissure" (road cracks)
|
| 125 |
+
- **50 images** with "Nid de poule" (potholes)
|
| 126 |
+
- **50 images** with road surface
|
| 127 |
|
| 128 |
+
Source: Logiroad CVAT organization, AI training project
|
| 129 |
+
|
| 130 |
+
### Output Structure
|
| 131 |
+
|
| 132 |
+
```
|
| 133 |
+
.cache/test/metrics/
|
| 134 |
+
├── Fissure/
|
| 135 |
+
│ └── {image_name}/
|
| 136 |
+
│ ├── image.jpg
|
| 137 |
+
│ ├── ground_truth/
|
| 138 |
+
│ │ ├── mask_Fissure_0.png
|
| 139 |
+
│ │ └── metadata.json
|
| 140 |
+
│ └── inference/
|
| 141 |
+
│ ├── mask_Fissure_0.png
|
| 142 |
+
│ └── metadata.json
|
| 143 |
+
├── Nid de poule/
|
| 144 |
+
├── Road/
|
| 145 |
+
├── metrics_summary.txt # Human-readable results
|
| 146 |
+
├── metrics_detailed.json # Complete metrics data
|
| 147 |
+
└── evaluation_log.txt # Execution trace
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Execution Time
|
| 151 |
+
|
| 152 |
+
- Image download: ~5-10 minutes (150 images)
|
| 153 |
+
- SAM3 inference: ~5-10 minutes (~2s per image)
|
| 154 |
+
- Metrics computation: ~1 minute
|
| 155 |
+
- **Total**: ~15-20 minutes for full evaluation
|
| 156 |
+
|
| 157 |
+
### Configuration
|
| 158 |
+
|
| 159 |
+
Edit `metrics_evaluation/config/config.json` to:
|
| 160 |
+
- Change CVAT project or organization
|
| 161 |
+
- Adjust number of images per class
|
| 162 |
+
- Modify IoU thresholds
|
| 163 |
+
- Update SAM3 endpoint URL
|
| 164 |
+
|
| 165 |
+
CVAT credentials must be in `.env` at project root.
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
## Cache Directory
|
| 170 |
+
|
| 171 |
+
All test results are stored in `.cache/` (git-ignored):
|
| 172 |
- Review results without cluttering the repository
|
| 173 |
- Compare results across different test runs
|
| 174 |
- Debug segmentation quality issues
|
| 175 |
+
- Resume interrupted evaluations
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## Quality Validation Checklist
|
| 180 |
+
|
| 181 |
+
Before accepting test results:
|
| 182 |
+
|
| 183 |
+
**Basic Tests**:
|
| 184 |
+
- [ ] All test images processed successfully
|
| 185 |
+
- [ ] Masks generated for all requested classes
|
| 186 |
+
- [ ] Response times reasonable (< 3s per image)
|
| 187 |
+
- [ ] Visualizations show plausible segmentations
|
| 188 |
+
|
| 189 |
+
**Metrics Evaluation**:
|
| 190 |
+
- [ ] 150 images downloaded from CVAT
|
| 191 |
+
- [ ] Ground truth masks not empty
|
| 192 |
+
- [ ] SAM3 inference completed for all images
|
| 193 |
+
- [ ] Metrics within reasonable ranges (0-100%)
|
| 194 |
+
- [ ] Confusion matrices show sensible patterns
|
| 195 |
+
- [ ] Per-class F1 scores above baseline
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## Troubleshooting
|
| 200 |
+
|
| 201 |
+
### Basic Inference Issues
|
| 202 |
+
|
| 203 |
+
**Endpoint not responding**:
|
| 204 |
+
- Check endpoint URL in test script
|
| 205 |
+
- Verify endpoint is running (use `curl` or browser)
|
| 206 |
+
- Check network connectivity
|
| 207 |
+
|
| 208 |
+
**Empty or invalid masks**:
|
| 209 |
+
- Review class names match model expectations
|
| 210 |
+
- Check image format (should be JPEG/PNG)
|
| 211 |
+
- Verify base64 encoding/decoding
|
| 212 |
+
|
| 213 |
+
### Metrics Evaluation Issues
|
| 214 |
+
|
| 215 |
+
**CVAT connection fails**:
|
| 216 |
+
- Check `.env` credentials
|
| 217 |
+
- Verify CVAT organization name
|
| 218 |
+
- Test CVAT web access
|
| 219 |
+
|
| 220 |
+
**No images found**:
|
| 221 |
+
- Check project filter in `config.json`
|
| 222 |
+
- Verify labels exist in CVAT
|
| 223 |
+
- Ensure images have annotations
|
| 224 |
+
|
| 225 |
+
**Metrics seem incorrect**:
|
| 226 |
+
- Inspect confusion matrices
|
| 227 |
+
- Review sample visualizations
|
| 228 |
+
- Check ground truth quality in CVAT
|
| 229 |
+
- Verify mask format (PNG-L, 8-bit grayscale)
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
|
| 233 |
+
## Next Steps
|
| 234 |
|
| 235 |
+
1. **Run basic tests** to validate API connectivity
|
| 236 |
+
2. **Review visualizations** to assess segmentation quality
|
| 237 |
+
3. **Run metrics evaluation** for quantitative performance
|
| 238 |
+
4. **Analyze confusion matrices** to identify systematic errors
|
| 239 |
+
5. **Iterate on model/prompts** based on metrics feedback
|
| 240 |
|
| 241 |
+
For detailed metrics evaluation documentation, see `metrics_evaluation/README.md`.
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
Git LFS Details
|
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Debug script to inspect CVAT labels and annotations."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from metrics_evaluation.cvat_api.client import CvatApiClient
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
# Connect to CVAT
|
| 10 |
+
client = CvatApiClient(
|
| 11 |
+
cvat_host="https://app.cvat.ai",
|
| 12 |
+
cvat_username=os.getenv("CVAT_USERNAME"),
|
| 13 |
+
cvat_password=os.getenv("CVAT_PASSWORD"),
|
| 14 |
+
cvat_organization="Logiroad",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Find the training project
|
| 18 |
+
projects = client.projects.list()
|
| 19 |
+
training_project = None
|
| 20 |
+
for project in projects:
|
| 21 |
+
if "Entrainement" in project.name:
|
| 22 |
+
training_project = project
|
| 23 |
+
break
|
| 24 |
+
|
| 25 |
+
if not training_project:
|
| 26 |
+
print("No training project found")
|
| 27 |
+
exit(1)
|
| 28 |
+
|
| 29 |
+
print(f"Project: {training_project.name} (ID: {training_project.id})")
|
| 30 |
+
|
| 31 |
+
# Get project labels
|
| 32 |
+
labels = client.projects.get_project_labels(training_project.id)
|
| 33 |
+
print(f"\nProject labels ({len(labels)}):")
|
| 34 |
+
for label in labels:
|
| 35 |
+
print(f" - {label.name} (ID: {label.id})")
|
| 36 |
+
|
| 37 |
+
# Get tasks
|
| 38 |
+
tasks = client.tasks.list(project_id=training_project.id)
|
| 39 |
+
print(f"\nTasks: {len(tasks)}")
|
| 40 |
+
|
| 41 |
+
# Check first few tasks for annotations
|
| 42 |
+
for i, task in enumerate(tasks[:3]):
|
| 43 |
+
print(f"\n--- Task {task.id}: {task.name} ---")
|
| 44 |
+
|
| 45 |
+
# Get jobs for this task
|
| 46 |
+
jobs = client.jobs.list(task_id=task.id)
|
| 47 |
+
print(f"Jobs: {len(jobs)}")
|
| 48 |
+
|
| 49 |
+
for job in jobs[:1]: # Just check first job
|
| 50 |
+
print(f" Job {job.id}:")
|
| 51 |
+
|
| 52 |
+
# Get annotations
|
| 53 |
+
annotations = client.annotations.get_job_annotations(job.id)
|
| 54 |
+
|
| 55 |
+
print(f" Tags: {len(annotations.tags)}")
|
| 56 |
+
print(f" Shapes: {len(annotations.shapes)}")
|
| 57 |
+
print(f" Tracks: {len(annotations.tracks)}")
|
| 58 |
+
|
| 59 |
+
# Show first few shapes
|
| 60 |
+
for j, shape in enumerate(annotations.shapes[:3]):
|
| 61 |
+
print(f" Shape {j}: type={shape.type}, label_id={shape.label_id}, label={shape.label}, frame={shape.frame}")
|
|
@@ -2,12 +2,11 @@
|
|
| 2 |
"cvat": {
|
| 3 |
"url": "https://app.cvat.ai",
|
| 4 |
"organization": "Logiroad",
|
| 5 |
-
"project_name_filter": "
|
| 6 |
},
|
| 7 |
"classes": {
|
| 8 |
"Fissure": 50,
|
| 9 |
-
"Nid de poule": 50
|
| 10 |
-
"Road": 50
|
| 11 |
},
|
| 12 |
"sam3": {
|
| 13 |
"endpoint": "https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud",
|
|
|
|
| 2 |
"cvat": {
|
| 3 |
"url": "https://app.cvat.ai",
|
| 4 |
"organization": "Logiroad",
|
| 5 |
+
"project_name_filter": "Entrainement"
|
| 6 |
},
|
| 7 |
"classes": {
|
| 8 |
"Fissure": 50,
|
| 9 |
+
"Nid de poule": 50
|
|
|
|
| 10 |
},
|
| 11 |
"sam3": {
|
| 12 |
"endpoint": "https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud",
|
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""CVAT API job methods."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from typing import TYPE_CHECKING
|
| 4 |
|
| 5 |
from metrics_evaluation.schema.cvat import (
|
|
@@ -35,6 +37,32 @@ class JobsMethods:
|
|
| 35 |
"""
|
| 36 |
self.client = client
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 39 |
def list_jobs(
|
| 40 |
self, request: CvatApiJobsListRequest, token: str | None = None
|
|
|
|
| 1 |
"""CVAT API job methods."""
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
from metrics_evaluation.schema.cvat import (
|
|
|
|
| 37 |
"""
|
| 38 |
self.client = client
|
| 39 |
|
| 40 |
+
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 41 |
+
def list(self, task_id: int | None = None, token: str | None = None) -> list[CvatApiJobDetails]:
|
| 42 |
+
"""List all jobs, optionally filtered by task.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
task_id: Filter by task ID (optional)
|
| 46 |
+
token: Authentication token (optional)
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
List of job details objects
|
| 50 |
+
"""
|
| 51 |
+
headers = self.client._get_headers(token)
|
| 52 |
+
url = f"{self.client.cvat_host}/api/jobs?page_size=1000"
|
| 53 |
+
if task_id is not None:
|
| 54 |
+
url += f"&task_id={task_id}"
|
| 55 |
+
|
| 56 |
+
response = self.client._make_request(
|
| 57 |
+
method="GET",
|
| 58 |
+
url=url,
|
| 59 |
+
headers=headers,
|
| 60 |
+
resource_name="jobs list",
|
| 61 |
+
response_model=CvatApiJobsListResponse,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return response.results
|
| 65 |
+
|
| 66 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 67 |
def list_jobs(
|
| 68 |
self, request: CvatApiJobsListRequest, token: str | None = None
|
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""CVAT API project methods."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from typing import TYPE_CHECKING
|
| 4 |
|
| 5 |
from metrics_evaluation.schema.cvat import CvatApiLabelDefinition, CvatApiProjectDetails
|
|
@@ -30,6 +32,32 @@ class ProjectsMethods:
|
|
| 30 |
"""
|
| 31 |
self.client = client
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 34 |
def get_project_details(
|
| 35 |
self, project_id: int, token: str | None = None
|
|
|
|
| 1 |
"""CVAT API project methods."""
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
from metrics_evaluation.schema.cvat import CvatApiLabelDefinition, CvatApiProjectDetails
|
|
|
|
| 32 |
"""
|
| 33 |
self.client = client
|
| 34 |
|
| 35 |
+
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 36 |
+
def list(self, token: str | None = None) -> list[CvatApiProjectDetails]:
|
| 37 |
+
"""List all projects accessible to the user.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
token: Authentication token (optional)
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
List of project details objects
|
| 44 |
+
"""
|
| 45 |
+
headers = self.client._get_headers(token)
|
| 46 |
+
url = f"{self.client.cvat_host}/api/projects?page_size=1000"
|
| 47 |
+
|
| 48 |
+
response = self.client._make_request(
|
| 49 |
+
method="GET",
|
| 50 |
+
url=url,
|
| 51 |
+
headers=headers,
|
| 52 |
+
resource_name="projects list",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
response_data = response.json()
|
| 56 |
+
return [
|
| 57 |
+
CvatApiProjectDetails.model_validate(project)
|
| 58 |
+
for project in response_data.get("results", [])
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 62 |
def get_project_details(
|
| 63 |
self, project_id: int, token: str | None = None
|
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""CVAT API task methods."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from typing import TYPE_CHECKING
|
| 4 |
|
| 5 |
from metrics_evaluation.schema.cvat import CvatApiTaskDetails, CvatApiTaskMediasMetainformation
|
|
@@ -30,6 +32,35 @@ class TasksMethods:
|
|
| 30 |
"""
|
| 31 |
self.client = client
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 34 |
def get_task_details(
|
| 35 |
self, task_id: int, token: str | None = None
|
|
@@ -133,3 +164,28 @@ class TasksMethods:
|
|
| 133 |
resource_id=task_id,
|
| 134 |
response_model=CvatApiTaskDetails,
|
| 135 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""CVAT API task methods."""
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
from metrics_evaluation.schema.cvat import CvatApiTaskDetails, CvatApiTaskMediasMetainformation
|
|
|
|
| 32 |
"""
|
| 33 |
self.client = client
|
| 34 |
|
| 35 |
+
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 36 |
+
def list(self, project_id: int | None = None, token: str | None = None) -> list[CvatApiTaskDetails]:
|
| 37 |
+
"""List all tasks, optionally filtered by project.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
project_id: Filter by project ID (optional)
|
| 41 |
+
token: Authentication token (optional)
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
List of task details objects
|
| 45 |
+
"""
|
| 46 |
+
headers = self.client._get_headers(token)
|
| 47 |
+
url = f"{self.client.cvat_host}/api/tasks?page_size=1000"
|
| 48 |
+
if project_id is not None:
|
| 49 |
+
url += f"&project_id={project_id}"
|
| 50 |
+
|
| 51 |
+
response = self.client._make_request(
|
| 52 |
+
method="GET",
|
| 53 |
+
url=url,
|
| 54 |
+
headers=headers,
|
| 55 |
+
resource_name="tasks list",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
response_data = response.json()
|
| 59 |
+
return [
|
| 60 |
+
CvatApiTaskDetails.model_validate(task)
|
| 61 |
+
for task in response_data.get("results", [])
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 65 |
def get_task_details(
|
| 66 |
self, task_id: int, token: str | None = None
|
|
|
|
| 164 |
resource_id=task_id,
|
| 165 |
response_model=CvatApiTaskDetails,
|
| 166 |
)
|
| 167 |
+
|
| 168 |
+
@retry_with_backoff(max_retries=3, initial_delay=1.0)
|
| 169 |
+
def get_frame(self, task_id: int, frame_number: int, token: str | None = None) -> bytes:
|
| 170 |
+
"""Download a single frame from a task.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
task_id: The ID of the task
|
| 174 |
+
frame_number: The frame number to download
|
| 175 |
+
token: Authentication token (optional)
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Raw image bytes
|
| 179 |
+
"""
|
| 180 |
+
headers = self.client._get_headers(token)
|
| 181 |
+
url = f"{self.client.cvat_host}/api/tasks/{task_id}/data?type=frame&number={frame_number}&quality=original"
|
| 182 |
+
|
| 183 |
+
response = self.client._make_request(
|
| 184 |
+
method="GET",
|
| 185 |
+
url=url,
|
| 186 |
+
headers=headers,
|
| 187 |
+
resource_name="task frame",
|
| 188 |
+
resource_id=task_id,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return response.content
|
|
@@ -28,6 +28,7 @@ class CVATExtractor:
|
|
| 28 |
self.config = config
|
| 29 |
self.client: CvatApiClient | None = None
|
| 30 |
self.project_id: int | None = None
|
|
|
|
| 31 |
|
| 32 |
def connect(self) -> None:
|
| 33 |
"""Connect to CVAT API.
|
|
@@ -52,9 +53,10 @@ class CVATExtractor:
|
|
| 52 |
|
| 53 |
try:
|
| 54 |
self.client = CvatApiClient(
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
| 58 |
)
|
| 59 |
logger.info(f"Connected to CVAT at {self.config.cvat.url}")
|
| 60 |
except Exception as e:
|
|
@@ -109,6 +111,11 @@ class CVATExtractor:
|
|
| 109 |
if not self.client or not self.project_id:
|
| 110 |
raise ValueError("Must connect and find project first")
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
tasks = self.client.tasks.list(project_id=self.project_id)
|
| 113 |
|
| 114 |
if not tasks:
|
|
@@ -142,7 +149,11 @@ class CVATExtractor:
|
|
| 142 |
|
| 143 |
# Check which classes are present in each frame
|
| 144 |
for frame_id, shapes in frame_annotations.items():
|
| 145 |
-
labels_in_frame = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
for class_name in self.config.classes.keys():
|
| 148 |
if class_name in labels_in_frame:
|
|
@@ -367,7 +378,9 @@ class CVATExtractor:
|
|
| 367 |
label_counts: dict[str, int] = {}
|
| 368 |
|
| 369 |
for shape in frame_masks:
|
| 370 |
-
label = shape.
|
|
|
|
|
|
|
| 371 |
if label not in label_counts:
|
| 372 |
label_counts[label] = 0
|
| 373 |
|
|
|
|
| 28 |
self.config = config
|
| 29 |
self.client: CvatApiClient | None = None
|
| 30 |
self.project_id: int | None = None
|
| 31 |
+
self.label_map: dict[int, str] = {}
|
| 32 |
|
| 33 |
def connect(self) -> None:
|
| 34 |
"""Connect to CVAT API.
|
|
|
|
| 53 |
|
| 54 |
try:
|
| 55 |
self.client = CvatApiClient(
|
| 56 |
+
cvat_host=self.config.cvat.url,
|
| 57 |
+
cvat_username=username,
|
| 58 |
+
cvat_password=password,
|
| 59 |
+
cvat_organization=self.config.cvat.organization,
|
| 60 |
)
|
| 61 |
logger.info(f"Connected to CVAT at {self.config.cvat.url}")
|
| 62 |
except Exception as e:
|
|
|
|
| 111 |
if not self.client or not self.project_id:
|
| 112 |
raise ValueError("Must connect and find project first")
|
| 113 |
|
| 114 |
+
# Get project labels to map label_id to label name
|
| 115 |
+
project_labels = self.client.projects.get_project_labels(self.project_id)
|
| 116 |
+
self.label_map = {label.id: label.name for label in project_labels}
|
| 117 |
+
logger.info(f"Loaded {len(self.label_map)} label definitions from project")
|
| 118 |
+
|
| 119 |
tasks = self.client.tasks.list(project_id=self.project_id)
|
| 120 |
|
| 121 |
if not tasks:
|
|
|
|
| 149 |
|
| 150 |
# Check which classes are present in each frame
|
| 151 |
for frame_id, shapes in frame_annotations.items():
|
| 152 |
+
labels_in_frame = {
|
| 153 |
+
self.label_map.get(shape.label_id)
|
| 154 |
+
for shape in shapes
|
| 155 |
+
if hasattr(shape, 'type') and shape.type == 'mask' and shape.label_id in self.label_map
|
| 156 |
+
}
|
| 157 |
|
| 158 |
for class_name in self.config.classes.keys():
|
| 159 |
if class_name in labels_in_frame:
|
|
|
|
| 378 |
label_counts: dict[str, int] = {}
|
| 379 |
|
| 380 |
for shape in frame_masks:
|
| 381 |
+
label = self.label_map.get(shape.label_id)
|
| 382 |
+
if not label:
|
| 383 |
+
continue
|
| 384 |
if label not in label_counts:
|
| 385 |
label_counts[label] = 0
|
| 386 |
|
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""SAM3 inference for evaluation."""
|
| 2 |
|
| 3 |
import base64
|
|
|
|
| 4 |
import json
|
| 5 |
import logging
|
| 6 |
import time
|
|
@@ -214,8 +215,6 @@ class SAM3Inferencer:
|
|
| 214 |
"skipped": 0,
|
| 215 |
}
|
| 216 |
|
| 217 |
-
import io
|
| 218 |
-
|
| 219 |
processed = 0
|
| 220 |
|
| 221 |
for class_name, paths in image_paths.items():
|
|
|
|
| 1 |
"""SAM3 inference for evaluation."""
|
| 2 |
|
| 3 |
import base64
|
| 4 |
+
import io
|
| 5 |
import json
|
| 6 |
import logging
|
| 7 |
import time
|
|
|
|
| 215 |
"skipped": 0,
|
| 216 |
}
|
| 217 |
|
|
|
|
|
|
|
| 218 |
processed = 0
|
| 219 |
|
| 220 |
for class_name, paths in image_paths.items():
|
|
@@ -69,11 +69,9 @@ class Request(BaseModel):
|
|
| 69 |
|
| 70 |
def run_inference(image_b64: str, classes: list, request_id: str):
|
| 71 |
"""
|
| 72 |
-
Sam3Model inference for static images with text prompts
|
| 73 |
|
| 74 |
-
|
| 75 |
-
- processor(images=image, text=text_prompts)
|
| 76 |
-
- model.forward(pixel_values, input_ids, ...)
|
| 77 |
"""
|
| 78 |
try:
|
| 79 |
# Decode image
|
|
@@ -90,7 +88,16 @@ def run_inference(image_b64: str, classes: list, request_id: str):
|
|
| 90 |
text=classes, # List of text prompts
|
| 91 |
return_tensors="pt"
|
| 92 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
|
|
|
|
| 94 |
|
| 95 |
# Move to GPU and match model dtype
|
| 96 |
if torch.cuda.is_available():
|
|
@@ -101,87 +108,136 @@ def run_inference(image_b64: str, classes: list, request_id: str):
|
|
| 101 |
}
|
| 102 |
logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")
|
| 103 |
|
| 104 |
-
logger.info(f"[{request_id}] Input keys: {list(inputs.keys())}")
|
| 105 |
-
|
| 106 |
# Sam3Model Inference
|
| 107 |
with torch.no_grad():
|
| 108 |
-
# Sam3Model.forward() accepts pixel_values, input_ids, etc.
|
| 109 |
outputs = model(**inputs)
|
| 110 |
logger.info(f"[{request_id}] Forward pass successful!")
|
| 111 |
|
| 112 |
logger.info(f"[{request_id}] Output type: {type(outputs)}")
|
| 113 |
-
logger.info(f"[{request_id}] Output attributes: {dir(outputs)}")
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
#
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")
|
| 126 |
-
else:
|
| 127 |
-
logger.error(f"[{request_id}] Unexpected output format")
|
| 128 |
-
logger.error(f"Output attributes: {dir(outputs) if not isinstance(outputs, dict) else outputs.keys()}")
|
| 129 |
-
raise ValueError("Cannot find masks in model output")
|
| 130 |
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
results = []
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
else:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
# Convert to PNG
|
| 166 |
-
pil_mask = Image.fromarray(binary_mask, mode="L")
|
| 167 |
-
buf = io.BytesIO()
|
| 168 |
-
pil_mask.save(buf, format="PNG")
|
| 169 |
-
mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 170 |
-
|
| 171 |
-
# Get confidence score if available
|
| 172 |
-
score = 1.0
|
| 173 |
-
if hasattr(outputs, 'pred_scores') and i < outputs.pred_scores.shape[1]:
|
| 174 |
-
score = float(outputs.pred_scores[0, i].cpu())
|
| 175 |
-
elif hasattr(outputs, 'scores') and i < len(outputs.scores):
|
| 176 |
-
score = float(outputs.scores[i].cpu() if hasattr(outputs.scores[i], 'cpu') else outputs.scores[i])
|
| 177 |
-
|
| 178 |
-
results.append({
|
| 179 |
-
"label": cls,
|
| 180 |
-
"mask": mask_b64,
|
| 181 |
-
"score": score
|
| 182 |
-
})
|
| 183 |
-
|
| 184 |
-
logger.info(f"[{request_id}] Completed: {len(results)} masks generated")
|
| 185 |
return results
|
| 186 |
|
| 187 |
except Exception as e:
|
|
|
|
| 69 |
|
| 70 |
def run_inference(image_b64: str, classes: list, request_id: str):
|
| 71 |
"""
|
| 72 |
+
Sam3Model inference for static images with text prompts.
|
| 73 |
|
| 74 |
+
Uses official SAM3 processor post-processing for correct mask generation.
|
|
|
|
|
|
|
| 75 |
"""
|
| 76 |
try:
|
| 77 |
# Decode image
|
|
|
|
| 88 |
text=classes, # List of text prompts
|
| 89 |
return_tensors="pt"
|
| 90 |
)
|
| 91 |
+
|
| 92 |
+
# Store original sizes for post-processing
|
| 93 |
+
# Format: [[height, width]] for EACH image in batch
|
| 94 |
+
# Since we repeat the image for each class, repeat the size too
|
| 95 |
+
original_size = [pil_image.size[1], pil_image.size[0]] # [height, width]
|
| 96 |
+
original_sizes = torch.tensor([original_size] * len(classes))
|
| 97 |
+
inputs["original_sizes"] = original_sizes
|
| 98 |
+
|
| 99 |
logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
|
| 100 |
+
logger.info(f"[{request_id}] Original size: {pil_image.size} (W x H)")
|
| 101 |
|
| 102 |
# Move to GPU and match model dtype
|
| 103 |
if torch.cuda.is_available():
|
|
|
|
| 108 |
}
|
| 109 |
logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")
|
| 110 |
|
|
|
|
|
|
|
| 111 |
# Sam3Model Inference
|
| 112 |
with torch.no_grad():
|
|
|
|
| 113 |
outputs = model(**inputs)
|
| 114 |
logger.info(f"[{request_id}] Forward pass successful!")
|
| 115 |
|
| 116 |
logger.info(f"[{request_id}] Output type: {type(outputs)}")
|
|
|
|
| 117 |
|
| 118 |
+
# Use processor's official post-processing method
|
| 119 |
+
# This handles:
|
| 120 |
+
# - Logit to probability conversion (sigmoid)
|
| 121 |
+
# - Proper thresholding (default 0.5)
|
| 122 |
+
# - Resizing to original image dimensions
|
| 123 |
+
# - Score extraction
|
| 124 |
+
logger.info(f"[{request_id}] Using processor.post_process_instance_segmentation()")
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
processed = processor.post_process_instance_segmentation(
|
| 128 |
+
outputs,
|
| 129 |
+
threshold=0.3, # Score threshold for detections (lowered to detect road cracks)
|
| 130 |
+
mask_threshold=0.5, # Probability threshold for mask pixels
|
| 131 |
+
target_sizes=original_sizes.tolist()
|
| 132 |
+
)
|
| 133 |
+
# Returns a LIST of results, one per image in batch (one per class in our case)
|
| 134 |
+
|
| 135 |
+
logger.info(f"[{request_id}] Post-processing successful!")
|
| 136 |
+
logger.info(f"[{request_id}] Number of batched results: {len(processed)}")
|
| 137 |
+
|
| 138 |
+
except Exception as proc_error:
|
| 139 |
+
logger.error(f"[{request_id}] Post-processing failed: {proc_error}")
|
| 140 |
+
logger.info(f"[{request_id}] Falling back to manual processing")
|
| 141 |
+
|
| 142 |
+
# Fallback to manual processing with sigmoid fix
|
| 143 |
+
results = []
|
| 144 |
+
|
| 145 |
+
# Extract masks from outputs
|
| 146 |
+
if hasattr(outputs, 'pred_masks'):
|
| 147 |
+
pred_masks = outputs.pred_masks
|
| 148 |
+
elif hasattr(outputs, 'masks'):
|
| 149 |
+
pred_masks = outputs.masks
|
| 150 |
+
elif isinstance(outputs, dict) and 'pred_masks' in outputs:
|
| 151 |
+
pred_masks = outputs['pred_masks']
|
| 152 |
+
else:
|
| 153 |
+
raise ValueError("Cannot find masks in model output")
|
| 154 |
+
|
| 155 |
logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
for i, cls in enumerate(classes):
|
| 158 |
+
if i < pred_masks.shape[1]:
|
| 159 |
+
mask_tensor = pred_masks[0, i]
|
| 160 |
+
|
| 161 |
+
# Resize to original size
|
| 162 |
+
if mask_tensor.shape[-2:] != pil_image.size[::-1]:
|
| 163 |
+
mask_tensor = torch.nn.functional.interpolate(
|
| 164 |
+
mask_tensor.unsqueeze(0).unsqueeze(0),
|
| 165 |
+
size=pil_image.size[::-1],
|
| 166 |
+
mode='bilinear',
|
| 167 |
+
align_corners=False
|
| 168 |
+
).squeeze()
|
| 169 |
+
|
| 170 |
+
# CRITICAL FIX: Convert logits to probabilities THEN threshold
|
| 171 |
+
probs = torch.sigmoid(mask_tensor)
|
| 172 |
+
binary_mask = (probs > 0.5).float().cpu().numpy().astype("uint8") * 255
|
| 173 |
+
else:
|
| 174 |
+
binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")
|
| 175 |
+
|
| 176 |
+
# Convert to PNG
|
| 177 |
+
pil_mask = Image.fromarray(binary_mask, mode="L")
|
| 178 |
+
buf = io.BytesIO()
|
| 179 |
+
pil_mask.save(buf, format="PNG")
|
| 180 |
+
mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 181 |
+
|
| 182 |
+
# Extract score
|
| 183 |
+
score = 1.0
|
| 184 |
+
if hasattr(outputs, 'pred_logits') and i < outputs.pred_logits.shape[1]:
|
| 185 |
+
# Convert logits to probability
|
| 186 |
+
score = float(torch.sigmoid(outputs.pred_logits[0, i]).cpu())
|
| 187 |
+
|
| 188 |
+
results.append({
|
| 189 |
+
"label": cls,
|
| 190 |
+
"mask": mask_b64,
|
| 191 |
+
"score": score
|
| 192 |
+
})
|
| 193 |
+
|
| 194 |
+
logger.info(f"[{request_id}] Completed (fallback): {len(results)} masks generated")
|
| 195 |
+
return results
|
| 196 |
+
|
| 197 |
+
# Extract results from processor output
|
| 198 |
+
# CRITICAL: processor returns one result dict per class (batched)
|
| 199 |
+
# Each result dict contains MULTIPLE instances of that class
|
| 200 |
results = []
|
| 201 |
|
| 202 |
+
total_instances = 0
|
| 203 |
+
for i, cls in enumerate(classes):
|
| 204 |
+
class_result = processed[i] # Results for this specific class
|
| 205 |
|
| 206 |
+
num_instances = len(class_result['masks']) if 'masks' in class_result else 0
|
| 207 |
+
total_instances += num_instances
|
| 208 |
|
| 209 |
+
if num_instances > 0:
|
| 210 |
+
logger.info(f"[{request_id}] {cls}: {num_instances} instance(s) detected")
|
| 211 |
+
|
| 212 |
+
# Loop through ALL instances of this class
|
| 213 |
+
for j in range(num_instances):
|
| 214 |
+
# Get mask (already binary, resized to original size)
|
| 215 |
+
mask_np = class_result['masks'][j].cpu().numpy().astype("uint8") * 255
|
| 216 |
+
|
| 217 |
+
# Convert to PNG
|
| 218 |
+
pil_mask = Image.fromarray(mask_np, mode="L")
|
| 219 |
+
buf = io.BytesIO()
|
| 220 |
+
pil_mask.save(buf, format="PNG")
|
| 221 |
+
mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 222 |
+
|
| 223 |
+
# Get score (already converted to probability by processor)
|
| 224 |
+
score = float(class_result['scores'][j]) if 'scores' in class_result else 1.0
|
| 225 |
+
|
| 226 |
+
# Calculate coverage for logging
|
| 227 |
+
coverage = (mask_np > 0).sum() / mask_np.size * 100
|
| 228 |
+
|
| 229 |
+
results.append({
|
| 230 |
+
"label": cls,
|
| 231 |
+
"mask": mask_b64,
|
| 232 |
+
"score": score,
|
| 233 |
+
"instance_id": j
|
| 234 |
+
})
|
| 235 |
+
|
| 236 |
+
logger.info(f"[{request_id}] └─ Instance {j}: score={score:.3f}, coverage={coverage:.2f}%")
|
| 237 |
else:
|
| 238 |
+
logger.info(f"[{request_id}] {cls}: No instances detected")
|
| 239 |
+
|
| 240 |
+
logger.info(f"[{request_id}] Completed: {total_instances} instance(s) across {len(classes)} class(es)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
return results
|
| 242 |
|
| 243 |
except Exception as e:
|
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SAM3 Static Image Segmentation - Correct Implementation
|
| 3 |
+
|
| 4 |
+
Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation.
|
| 5 |
+
"""
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
import asyncio
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from fastapi import FastAPI, HTTPException
|
| 13 |
+
from pydantic import BaseModel
|
| 14 |
+
from transformers import AutoProcessor, AutoModel
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Load SAM3 model for STATIC IMAGES
|
| 21 |
+
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
|
| 22 |
+
model = AutoModel.from_pretrained(
|
| 23 |
+
"./model",
|
| 24 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 25 |
+
trust_remote_code=True
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
model.eval()
|
| 29 |
+
if torch.cuda.is_available():
|
| 30 |
+
model.cuda()
|
| 31 |
+
logger.info(f"GPU: {torch.cuda.get_device_name()}")
|
| 32 |
+
logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
| 33 |
+
|
| 34 |
+
logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")
|
| 35 |
+
|
| 36 |
+
# Simple concurrency control
|
| 37 |
+
class VRAMManager:
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.semaphore = asyncio.Semaphore(2)
|
| 40 |
+
self.processing_count = 0
|
| 41 |
+
|
| 42 |
+
def get_vram_status(self):
|
| 43 |
+
if not torch.cuda.is_available():
|
| 44 |
+
return {}
|
| 45 |
+
return {
|
| 46 |
+
"total_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
|
| 47 |
+
"allocated_gb": torch.cuda.memory_allocated() / 1e9,
|
| 48 |
+
"free_gb": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1e9,
|
| 49 |
+
"processing_now": self.processing_count
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
async def acquire(self, rid):
|
| 53 |
+
await self.semaphore.acquire()
|
| 54 |
+
self.processing_count += 1
|
| 55 |
+
|
| 56 |
+
def release(self, rid):
|
| 57 |
+
self.processing_count -= 1
|
| 58 |
+
self.semaphore.release()
|
| 59 |
+
if torch.cuda.is_available():
|
| 60 |
+
torch.cuda.empty_cache()
|
| 61 |
+
|
| 62 |
+
vram_manager = VRAMManager()
|
| 63 |
+
app = FastAPI(title="SAM3 Static Image API")
|
| 64 |
+
|
| 65 |
+
class Request(BaseModel):
|
| 66 |
+
inputs: str
|
| 67 |
+
parameters: dict
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def run_inference(image_b64: str, classes: list, request_id: str):
|
| 71 |
+
"""
|
| 72 |
+
Sam3Model inference for static images with text prompts
|
| 73 |
+
|
| 74 |
+
According to HuggingFace docs, Sam3Model uses:
|
| 75 |
+
- processor(images=image, text=text_prompts)
|
| 76 |
+
- model.forward(pixel_values, input_ids, ...)
|
| 77 |
+
"""
|
| 78 |
+
try:
|
| 79 |
+
# Decode image
|
| 80 |
+
image_bytes = base64.b64decode(image_b64)
|
| 81 |
+
pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 82 |
+
logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")
|
| 83 |
+
|
| 84 |
+
# Process with Sam3Processor
|
| 85 |
+
# Sam3Model expects: batch of images matching text prompts
|
| 86 |
+
# For multiple objects in ONE image, repeat the image for each class
|
| 87 |
+
images_batch = [pil_image] * len(classes)
|
| 88 |
+
inputs = processor(
|
| 89 |
+
images=images_batch, # Repeat image for each text prompt
|
| 90 |
+
text=classes, # List of text prompts
|
| 91 |
+
return_tensors="pt"
|
| 92 |
+
)
|
| 93 |
+
logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
|
| 94 |
+
|
| 95 |
+
# Move to GPU and match model dtype
|
| 96 |
+
if torch.cuda.is_available():
|
| 97 |
+
model_dtype = next(model.parameters()).dtype
|
| 98 |
+
inputs = {
|
| 99 |
+
k: v.cuda().to(model_dtype) if isinstance(v, torch.Tensor) and v.dtype.is_floating_point else v.cuda() if isinstance(v, torch.Tensor) else v
|
| 100 |
+
for k, v in inputs.items()
|
| 101 |
+
}
|
| 102 |
+
logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")
|
| 103 |
+
|
| 104 |
+
logger.info(f"[{request_id}] Input keys: {list(inputs.keys())}")
|
| 105 |
+
|
| 106 |
+
# Sam3Model Inference
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
# Sam3Model.forward() accepts pixel_values, input_ids, etc.
|
| 109 |
+
outputs = model(**inputs)
|
| 110 |
+
logger.info(f"[{request_id}] Forward pass successful!")
|
| 111 |
+
|
| 112 |
+
logger.info(f"[{request_id}] Output type: {type(outputs)}")
|
| 113 |
+
logger.info(f"[{request_id}] Output attributes: {dir(outputs)}")
|
| 114 |
+
|
| 115 |
+
# Extract masks from outputs
|
| 116 |
+
# Sam3Model returns masks in outputs.pred_masks
|
| 117 |
+
if hasattr(outputs, 'pred_masks'):
|
| 118 |
+
pred_masks = outputs.pred_masks
|
| 119 |
+
logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")
|
| 120 |
+
elif hasattr(outputs, 'masks'):
|
| 121 |
+
pred_masks = outputs.masks
|
| 122 |
+
logger.info(f"[{request_id}] masks shape: {pred_masks.shape}")
|
| 123 |
+
elif isinstance(outputs, dict) and 'pred_masks' in outputs:
|
| 124 |
+
pred_masks = outputs['pred_masks']
|
| 125 |
+
logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")
|
| 126 |
+
else:
|
| 127 |
+
logger.error(f"[{request_id}] Unexpected output format")
|
| 128 |
+
logger.error(f"Output attributes: {dir(outputs) if not isinstance(outputs, dict) else outputs.keys()}")
|
| 129 |
+
raise ValueError("Cannot find masks in model output")
|
| 130 |
+
|
| 131 |
+
# Process masks
|
| 132 |
+
results = []
|
| 133 |
+
|
| 134 |
+
# pred_masks typically: [batch, num_objects, height, width]
|
| 135 |
+
batch_size = pred_masks.shape[0]
|
| 136 |
+
num_masks = pred_masks.shape[1] if len(pred_masks.shape) > 1 else 1
|
| 137 |
+
|
| 138 |
+
logger.info(f"[{request_id}] Batch size: {batch_size}, Num masks: {num_masks}")
|
| 139 |
+
|
| 140 |
+
for i, cls in enumerate(classes):
|
| 141 |
+
if i < num_masks:
|
| 142 |
+
# Get mask for this class/object
|
| 143 |
+
if len(pred_masks.shape) == 4: # [batch, num, h, w]
|
| 144 |
+
mask_tensor = pred_masks[0, i] # [h, w]
|
| 145 |
+
elif len(pred_masks.shape) == 3: # [num, h, w]
|
| 146 |
+
mask_tensor = pred_masks[i]
|
| 147 |
+
else:
|
| 148 |
+
mask_tensor = pred_masks
|
| 149 |
+
|
| 150 |
+
# Resize to original size if needed
|
| 151 |
+
if mask_tensor.shape[-2:] != pil_image.size[::-1]:
|
| 152 |
+
mask_tensor = torch.nn.functional.interpolate(
|
| 153 |
+
mask_tensor.unsqueeze(0).unsqueeze(0),
|
| 154 |
+
size=pil_image.size[::-1],
|
| 155 |
+
mode='bilinear',
|
| 156 |
+
align_corners=False
|
| 157 |
+
).squeeze()
|
| 158 |
+
|
| 159 |
+
# Convert to binary mask
|
| 160 |
+
binary_mask = (mask_tensor > 0.0).float().cpu().numpy().astype("uint8") * 255
|
| 161 |
+
else:
|
| 162 |
+
# No mask available for this class
|
| 163 |
+
binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")
|
| 164 |
+
|
| 165 |
+
# Convert to PNG
|
| 166 |
+
pil_mask = Image.fromarray(binary_mask, mode="L")
|
| 167 |
+
buf = io.BytesIO()
|
| 168 |
+
pil_mask.save(buf, format="PNG")
|
| 169 |
+
mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 170 |
+
|
| 171 |
+
# Get confidence score if available
|
| 172 |
+
score = 1.0
|
| 173 |
+
if hasattr(outputs, 'pred_scores') and i < outputs.pred_scores.shape[1]:
|
| 174 |
+
score = float(outputs.pred_scores[0, i].cpu())
|
| 175 |
+
elif hasattr(outputs, 'scores') and i < len(outputs.scores):
|
| 176 |
+
score = float(outputs.scores[i].cpu() if hasattr(outputs.scores[i], 'cpu') else outputs.scores[i])
|
| 177 |
+
|
| 178 |
+
results.append({
|
| 179 |
+
"label": cls,
|
| 180 |
+
"mask": mask_b64,
|
| 181 |
+
"score": score
|
| 182 |
+
})
|
| 183 |
+
|
| 184 |
+
logger.info(f"[{request_id}] Completed: {len(results)} masks generated")
|
| 185 |
+
return results
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
logger.error(f"[{request_id}] Failed: {str(e)}")
|
| 189 |
+
import traceback
|
| 190 |
+
traceback.print_exc()
|
| 191 |
+
raise
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@app.post("/")
|
| 195 |
+
async def predict(req: Request):
|
| 196 |
+
request_id = str(id(req))[:8]
|
| 197 |
+
try:
|
| 198 |
+
await vram_manager.acquire(request_id)
|
| 199 |
+
try:
|
| 200 |
+
results = await asyncio.to_thread(
|
| 201 |
+
run_inference,
|
| 202 |
+
req.inputs,
|
| 203 |
+
req.parameters.get("classes", []),
|
| 204 |
+
request_id
|
| 205 |
+
)
|
| 206 |
+
return results
|
| 207 |
+
finally:
|
| 208 |
+
vram_manager.release(request_id)
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.error(f"[{request_id}] Error: {str(e)}")
|
| 211 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@app.get("/health")
|
| 215 |
+
async def health():
|
| 216 |
+
return {
|
| 217 |
+
"status": "healthy",
|
| 218 |
+
"model": model.__class__.__name__,
|
| 219 |
+
"gpu_available": torch.cuda.is_available(),
|
| 220 |
+
"vram": vram_manager.get_vram_status()
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@app.get("/metrics")
|
| 225 |
+
async def metrics():
|
| 226 |
+
return vram_manager.get_vram_status()
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
|
| 230 |
+
import uvicorn
|
| 231 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|