Initial commit
Browse files- DEPLOYMENT.md +257 -0
- Dockerfile +101 -0
- app/__init__.py +9 -0
- app/api/__init__.py +1 -0
- app/api/dependencies.py +47 -0
- app/api/v1/__init__.py +1 -0
- app/api/v1/api.py +20 -0
- app/api/v1/endpoints/__init__.py +1 -0
- app/api/v1/endpoints/health.py +150 -0
- app/api/v1/endpoints/inference.py +163 -0
- app/core/__init__.py +1 -0
- app/core/config.py +54 -0
- app/core/logging.py +41 -0
- app/main.py +119 -0
- app/models/__init__.py +1 -0
- app/models/inference.py +123 -0
- app/models/yolo.py +186 -0
- app/services/__init__.py +1 -0
- app/services/inference_service.py +201 -0
- app/services/model_service.py +166 -0
- app/utils/__init__.py +1 -0
- app/utils/image_processing.py +193 -0
- app/utils/model_utils.py +227 -0
- app/utils/performance.py +237 -0
- requirements.txt +29 -0
- run_tests.py +76 -0
- start_api.py +131 -0
- test_api_simple.py +173 -0
- tests/__init__.py +1 -0
- tests/test_api.py +175 -0
DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🐟 Marine Species Identification API - Deployment Guide
|
| 2 |
+
|
| 3 |
+
## Quick Start
|
| 4 |
+
|
| 5 |
+
### 1. Local Development
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
# Install dependencies
|
| 9 |
+
pip install -r requirements.txt
|
| 10 |
+
|
| 11 |
+
# Start the API (with automatic model download)
|
| 12 |
+
python start_api.py
|
| 13 |
+
|
| 14 |
+
# Or start directly with uvicorn
|
| 15 |
+
uvicorn app.main:app --host 0.0.0.0 --port 7860
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
### 2. Docker Deployment
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
# Build the Docker image
|
| 22 |
+
docker build -t marine-species-api .
|
| 23 |
+
|
| 24 |
+
# Run the container
|
| 25 |
+
docker run -p 7860:7860 marine-species-api
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### 3. HuggingFace Spaces Deployment
|
| 29 |
+
|
| 30 |
+
1. Create a new Space on HuggingFace Hub
|
| 31 |
+
2. Set SDK to "Docker"
|
| 32 |
+
3. Upload all files to your Space repository
|
| 33 |
+
4. The Space will automatically build and deploy
|
| 34 |
+
|
| 35 |
+
## Configuration
|
| 36 |
+
|
| 37 |
+
### Environment Variables
|
| 38 |
+
|
| 39 |
+
Copy `.env.example` to `.env` and modify as needed:
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
# Model Configuration
|
| 43 |
+
HUGGINGFACE_REPO=your-username/your-model-repo
|
| 44 |
+
MODEL_NAME=marina-benthic-33k
|
| 45 |
+
|
| 46 |
+
# Performance
|
| 47 |
+
ENABLE_MODEL_WARMUP=true
|
| 48 |
+
MAX_FILE_SIZE=10485760
|
| 49 |
+
|
| 50 |
+
# Server
|
| 51 |
+
HOST=0.0.0.0
|
| 52 |
+
PORT=7860
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
### Model Setup
|
| 56 |
+
|
| 57 |
+
The API will automatically download the model from HuggingFace Hub on first startup. Ensure your model repository contains:
|
| 58 |
+
|
| 59 |
+
- `marina-benthic-33k.pt` - The YOLOv5 model file
|
| 60 |
+
- `marina-benthic-33k.names` - Class names file (optional)
|
| 61 |
+
|
| 62 |
+
## Testing
|
| 63 |
+
|
| 64 |
+
### Run All Tests
|
| 65 |
+
```bash
|
| 66 |
+
python run_tests.py
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Test API Manually
|
| 70 |
+
```bash
|
| 71 |
+
# Start the API
|
| 72 |
+
python start_api.py
|
| 73 |
+
|
| 74 |
+
# In another terminal, test the API
|
| 75 |
+
python test_api_simple.py
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Test Specific Endpoints
|
| 79 |
+
```bash
|
| 80 |
+
# Health check
|
| 81 |
+
curl http://localhost:7860/api/v1/health
|
| 82 |
+
|
| 83 |
+
# API info
|
| 84 |
+
curl http://localhost:7860/api/v1/info
|
| 85 |
+
|
| 86 |
+
# List species
|
| 87 |
+
curl http://localhost:7860/api/v1/species
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## API Usage
|
| 91 |
+
|
| 92 |
+
### Detect Marine Species
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
import requests
|
| 96 |
+
import base64
|
| 97 |
+
|
| 98 |
+
# Read and encode image
|
| 99 |
+
with open("marine_image.jpg", "rb") as f:
|
| 100 |
+
image_data = base64.b64encode(f.read()).decode()
|
| 101 |
+
|
| 102 |
+
# Make detection request
|
| 103 |
+
response = requests.post("http://localhost:7860/api/v1/detect", json={
|
| 104 |
+
"image": image_data,
|
| 105 |
+
"confidence_threshold": 0.25,
|
| 106 |
+
"return_annotated_image": True
|
| 107 |
+
})
|
| 108 |
+
|
| 109 |
+
result = response.json()
|
| 110 |
+
print(f"Found {len(result['detections'])} marine species")
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### JavaScript Example
|
| 114 |
+
|
| 115 |
+
```javascript
|
| 116 |
+
// Convert image to base64
|
| 117 |
+
const imageBase64 = await convertImageToBase64(imageFile);
|
| 118 |
+
|
| 119 |
+
// Make detection request
|
| 120 |
+
const response = await fetch('/api/v1/detect', {
|
| 121 |
+
method: 'POST',
|
| 122 |
+
headers: {
|
| 123 |
+
'Content-Type': 'application/json',
|
| 124 |
+
},
|
| 125 |
+
body: JSON.stringify({
|
| 126 |
+
image: imageBase64,
|
| 127 |
+
confidence_threshold: 0.25,
|
| 128 |
+
return_annotated_image: true
|
| 129 |
+
})
|
| 130 |
+
});
|
| 131 |
+
|
| 132 |
+
const result = await response.json();
|
| 133 |
+
console.log('Detections:', result.detections);
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## Performance Optimization
|
| 137 |
+
|
| 138 |
+
### Model Caching
|
| 139 |
+
- Model is loaded once on startup and cached in memory
|
| 140 |
+
- Supports model warmup for faster first inference
|
| 141 |
+
|
| 142 |
+
### Request Caching
|
| 143 |
+
- Optional caching of inference results
|
| 144 |
+
- Configurable TTL and cache size
|
| 145 |
+
|
| 146 |
+
### Monitoring
|
| 147 |
+
- Built-in performance metrics
|
| 148 |
+
- System resource monitoring
|
| 149 |
+
- Request timing and success rates
|
| 150 |
+
|
| 151 |
+
## Troubleshooting
|
| 152 |
+
|
| 153 |
+
### Common Issues
|
| 154 |
+
|
| 155 |
+
1. **Model Download Fails**
|
| 156 |
+
```bash
|
| 157 |
+
# Check repository access
|
| 158 |
+
python -c "from huggingface_hub import list_repo_files; print(list_repo_files('your-repo'))"
|
| 159 |
+
|
| 160 |
+
# Manual download
|
| 161 |
+
python app/utils/model_utils.py --download
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
2. **Out of Memory**
|
| 165 |
+
- Reduce `image_size` parameter
|
| 166 |
+
- Use CPU instead of GPU for inference
|
| 167 |
+
- Increase Docker memory limits
|
| 168 |
+
|
| 169 |
+
3. **Slow Inference**
|
| 170 |
+
- Enable model warmup
|
| 171 |
+
- Use GPU if available
|
| 172 |
+
- Optimize image preprocessing
|
| 173 |
+
|
| 174 |
+
### Health Checks
|
| 175 |
+
|
| 176 |
+
```bash
|
| 177 |
+
# Basic health
|
| 178 |
+
curl http://localhost:7860/health
|
| 179 |
+
|
| 180 |
+
# Detailed health with model status
|
| 181 |
+
curl http://localhost:7860/api/v1/health
|
| 182 |
+
|
| 183 |
+
# Readiness probe (for Kubernetes)
|
| 184 |
+
curl http://localhost:7860/api/v1/ready
|
| 185 |
+
|
| 186 |
+
# Liveness probe
|
| 187 |
+
curl http://localhost:7860/api/v1/live
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
## Production Deployment
|
| 191 |
+
|
| 192 |
+
### Docker Compose
|
| 193 |
+
|
| 194 |
+
```yaml
|
| 195 |
+
version: '3.8'
|
| 196 |
+
services:
|
| 197 |
+
marine-api:
|
| 198 |
+
build: .
|
| 199 |
+
ports:
|
| 200 |
+
- "7860:7860"
|
| 201 |
+
environment:
|
| 202 |
+
- ENABLE_MODEL_WARMUP=true
|
| 203 |
+
healthcheck:
|
| 204 |
+
test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
|
| 205 |
+
interval: 30s
|
| 206 |
+
timeout: 10s
|
| 207 |
+
retries: 3
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
### Kubernetes
|
| 211 |
+
|
| 212 |
+
```yaml
|
| 213 |
+
apiVersion: apps/v1
|
| 214 |
+
kind: Deployment
|
| 215 |
+
metadata:
|
| 216 |
+
name: marine-species-api
|
| 217 |
+
spec:
|
| 218 |
+
replicas: 2
|
| 219 |
+
selector:
|
| 220 |
+
matchLabels:
|
| 221 |
+
app: marine-species-api
|
| 222 |
+
template:
|
| 223 |
+
metadata:
|
| 224 |
+
labels:
|
| 225 |
+
app: marine-species-api
|
| 226 |
+
spec:
|
| 227 |
+
containers:
|
| 228 |
+
- name: api
|
| 229 |
+
image: marine-species-api:latest
|
| 230 |
+
ports:
|
| 231 |
+
- containerPort: 7860
|
| 232 |
+
livenessProbe:
|
| 233 |
+
httpGet:
|
| 234 |
+
path: /api/v1/live
|
| 235 |
+
port: 7860
|
| 236 |
+
initialDelaySeconds: 60
|
| 237 |
+
readinessProbe:
|
| 238 |
+
httpGet:
|
| 239 |
+
path: /api/v1/ready
|
| 240 |
+
port: 7860
|
| 241 |
+
initialDelaySeconds: 30
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
## Security Considerations
|
| 245 |
+
|
| 246 |
+
- Configure CORS appropriately for production
|
| 247 |
+
- Add rate limiting for public APIs
|
| 248 |
+
- Validate and sanitize all inputs
|
| 249 |
+
- Use HTTPS in production
|
| 250 |
+
- Monitor for unusual usage patterns
|
| 251 |
+
|
| 252 |
+
## Monitoring and Logging
|
| 253 |
+
|
| 254 |
+
- Structured logging with configurable levels
|
| 255 |
+
- Performance metrics collection
|
| 256 |
+
- Health check endpoints for monitoring systems
|
| 257 |
+
- Error tracking and alerting
|
Dockerfile
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Marine Species Identification API on HuggingFace Spaces
|
| 2 |
+
# Multi-stage build to handle model downloading with proper permissions
|
| 3 |
+
|
| 4 |
+
# Stage 1: Download models as root
|
| 5 |
+
FROM python:3.10-slim AS model-builder
|
| 6 |
+
|
| 7 |
+
# Install system dependencies for model downloading
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
wget \
|
| 10 |
+
curl \
|
| 11 |
+
git \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
# Install huggingface_hub for downloading models
|
| 15 |
+
RUN pip install huggingface_hub
|
| 16 |
+
|
| 17 |
+
# Create models directory
|
| 18 |
+
RUN mkdir -p /models
|
| 19 |
+
|
| 20 |
+
# Download models from HuggingFace Hub
|
| 21 |
+
# Note: Replace 'seamo-ai/marina-species-v1' with your actual HF repo
|
| 22 |
+
RUN python -c "\
|
| 23 |
+
from huggingface_hub import hf_hub_download; \
|
| 24 |
+
import os; \
|
| 25 |
+
try: \
|
| 26 |
+
hf_hub_download('seamo-ai/marina-species-v1', 'marina-benthic-33k.pt', local_dir='/models', local_dir_use_symlinks=False); \
|
| 27 |
+
print('Model downloaded successfully'); \
|
| 28 |
+
except Exception as e: \
|
| 29 |
+
print(f'Model download failed: {e}'); \
|
| 30 |
+
# Create a placeholder file to prevent build failure \
|
| 31 |
+
with open('/models/marina-benthic-33k.pt', 'w') as f: \
|
| 32 |
+
f.write('placeholder'); \
|
| 33 |
+
"
|
| 34 |
+
|
| 35 |
+
# Try to download class names file
|
| 36 |
+
RUN python -c "\
|
| 37 |
+
from huggingface_hub import hf_hub_download; \
|
| 38 |
+
try: \
|
| 39 |
+
hf_hub_download('seamo-ai/marina-species-v1', 'marina-benthic-33k.names', local_dir='/models', local_dir_use_symlinks=False); \
|
| 40 |
+
print('Class names downloaded successfully'); \
|
| 41 |
+
except Exception as e: \
|
| 42 |
+
print(f'Class names download failed: {e}'); \
|
| 43 |
+
"
|
| 44 |
+
|
| 45 |
+
# Stage 2: Build the application
|
| 46 |
+
FROM python:3.10-slim
|
| 47 |
+
|
| 48 |
+
# Install system dependencies
|
| 49 |
+
RUN apt-get update && apt-get install -y \
|
| 50 |
+
ffmpeg \
|
| 51 |
+
libsm6 \
|
| 52 |
+
libxext6 \
|
| 53 |
+
libxrender-dev \
|
| 54 |
+
libglib2.0-0 \
|
| 55 |
+
libgomp1 \
|
| 56 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 57 |
+
|
| 58 |
+
# Set up a new user named "user" with user ID 1000
|
| 59 |
+
RUN useradd -m -u 1000 user
|
| 60 |
+
|
| 61 |
+
# Switch to the "user" user
|
| 62 |
+
USER user
|
| 63 |
+
|
| 64 |
+
# Set home to the user's home directory
|
| 65 |
+
ENV HOME=/home/user \
|
| 66 |
+
PATH=/home/user/.local/bin:$PATH
|
| 67 |
+
|
| 68 |
+
# Set the working directory to the user's home directory
|
| 69 |
+
WORKDIR $HOME/app
|
| 70 |
+
|
| 71 |
+
# Set environment variables for HuggingFace and ML libraries
|
| 72 |
+
ENV HF_HUB_OFFLINE=1
|
| 73 |
+
ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
| 74 |
+
ENV PYTHONPATH=$HOME/app
|
| 75 |
+
ENV TORCH_HOME=$HOME/.cache/torch
|
| 76 |
+
ENV HF_HOME=$HOME/.cache/huggingface
|
| 77 |
+
|
| 78 |
+
# Copy the requirements file and install dependencies
|
| 79 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 80 |
+
RUN pip install --no-cache-dir --upgrade pip
|
| 81 |
+
RUN pip install --no-cache-dir --user -r requirements.txt
|
| 82 |
+
|
| 83 |
+
# Copy the downloaded models from the builder stage
|
| 84 |
+
COPY --chown=user --from=model-builder /models $HOME/app/models
|
| 85 |
+
|
| 86 |
+
# Copy the application code
|
| 87 |
+
COPY --chown=user ./app app
|
| 88 |
+
|
| 89 |
+
# Create necessary directories
|
| 90 |
+
RUN mkdir -p $HOME/.cache/huggingface $HOME/.cache/torch
|
| 91 |
+
|
| 92 |
+
# Expose port 7860 (HuggingFace Spaces standard)
|
| 93 |
+
EXPOSE 7860
|
| 94 |
+
|
| 95 |
+
# Health check
|
| 96 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
|
| 97 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 98 |
+
|
| 99 |
+
# Tell uvicorn to run on port 7860, which is the standard for HF Spaces
|
| 100 |
+
# Use 0.0.0.0 to make it accessible from outside the container
|
| 101 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
app/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Marine Species Identification API
|
| 3 |
+
|
| 4 |
+
A scalable API for marine species identification using YOLOv5 model.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__version__ = "1.0.0"
|
| 8 |
+
__author__ = "Seamo AI"
|
| 9 |
+
__description__ = "Marine Species Identification API"
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API package
|
app/api/dependencies.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI dependencies for the marine species identification API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import HTTPException, status
|
| 6 |
+
from app.services.model_service import model_service
|
| 7 |
+
from app.core.logging import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def get_model_service():
|
| 13 |
+
"""
|
| 14 |
+
Dependency to get the model service.
|
| 15 |
+
Ensures the model is loaded and available.
|
| 16 |
+
"""
|
| 17 |
+
try:
|
| 18 |
+
await model_service.ensure_model_available()
|
| 19 |
+
return model_service
|
| 20 |
+
except Exception as e:
|
| 21 |
+
logger.error(f"Failed to initialize model service: {str(e)}")
|
| 22 |
+
raise HTTPException(
|
| 23 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 24 |
+
detail=f"Model service unavailable: {str(e)}"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def validate_model_health():
|
| 29 |
+
"""
|
| 30 |
+
Dependency to validate model health before processing requests.
|
| 31 |
+
"""
|
| 32 |
+
try:
|
| 33 |
+
health_status = await model_service.health_check()
|
| 34 |
+
if not health_status.get("model_loaded", False):
|
| 35 |
+
raise HTTPException(
|
| 36 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 37 |
+
detail="Model is not loaded or unhealthy"
|
| 38 |
+
)
|
| 39 |
+
return True
|
| 40 |
+
except HTTPException:
|
| 41 |
+
raise
|
| 42 |
+
except Exception as e:
|
| 43 |
+
logger.error(f"Model health check failed: {str(e)}")
|
| 44 |
+
raise HTTPException(
|
| 45 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 46 |
+
detail="Model health check failed"
|
| 47 |
+
)
|
app/api/v1/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API v1 package
|
app/api/v1/api.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API v1 router configuration.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
from app.api.v1.endpoints import inference, health
|
| 8 |
+
|
| 9 |
+
api_router = APIRouter()
|
| 10 |
+
|
| 11 |
+
# Include endpoint routers
|
| 12 |
+
api_router.include_router(
|
| 13 |
+
inference.router,
|
| 14 |
+
tags=["Marine Species Detection"]
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
api_router.include_router(
|
| 18 |
+
health.router,
|
| 19 |
+
tags=["Health & Status"]
|
| 20 |
+
)
|
app/api/v1/endpoints/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API v1 endpoints
|
app/api/v1/endpoints/health.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Health check and system status endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from fastapi import APIRouter, HTTPException, status
|
| 7 |
+
|
| 8 |
+
from app.models.inference import HealthResponse, APIInfo, ModelInfo, ErrorResponse
|
| 9 |
+
from app.services.model_service import model_service
|
| 10 |
+
from app.core.config import settings
|
| 11 |
+
from app.core.logging import get_logger
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
router = APIRouter()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get(
|
| 19 |
+
"/health",
|
| 20 |
+
response_model=HealthResponse,
|
| 21 |
+
summary="Health Check",
|
| 22 |
+
description="Check the health status of the API and model"
|
| 23 |
+
)
|
| 24 |
+
async def health_check() -> HealthResponse:
|
| 25 |
+
"""
|
| 26 |
+
Perform a comprehensive health check of the API and model.
|
| 27 |
+
|
| 28 |
+
Returns the current status of the API, model loading status, and basic model information.
|
| 29 |
+
This endpoint can be used for monitoring and load balancer health checks.
|
| 30 |
+
"""
|
| 31 |
+
try:
|
| 32 |
+
logger.debug("Performing health check")
|
| 33 |
+
|
| 34 |
+
# Check model health
|
| 35 |
+
health_status = await model_service.health_check()
|
| 36 |
+
|
| 37 |
+
model_info = None
|
| 38 |
+
if health_status.get("model_loaded", False):
|
| 39 |
+
model_info_dict = health_status.get("model_info", {})
|
| 40 |
+
if model_info_dict:
|
| 41 |
+
model_info = ModelInfo(**model_info_dict)
|
| 42 |
+
|
| 43 |
+
return HealthResponse(
|
| 44 |
+
status="healthy" if health_status.get("model_loaded", False) else "degraded",
|
| 45 |
+
model_loaded=health_status.get("model_loaded", False),
|
| 46 |
+
model_info=model_info,
|
| 47 |
+
timestamp=datetime.utcnow().isoformat()
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.error(f"Health check failed: {str(e)}")
|
| 52 |
+
return HealthResponse(
|
| 53 |
+
status="unhealthy",
|
| 54 |
+
model_loaded=False,
|
| 55 |
+
model_info=None,
|
| 56 |
+
timestamp=datetime.utcnow().isoformat()
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@router.get(
|
| 61 |
+
"/info",
|
| 62 |
+
response_model=APIInfo,
|
| 63 |
+
summary="API Information",
|
| 64 |
+
description="Get comprehensive information about the API and model"
|
| 65 |
+
)
|
| 66 |
+
async def get_api_info() -> APIInfo:
|
| 67 |
+
"""
|
| 68 |
+
Get comprehensive information about the API.
|
| 69 |
+
|
| 70 |
+
Returns detailed information about the API version, capabilities, model information,
|
| 71 |
+
and available endpoints.
|
| 72 |
+
"""
|
| 73 |
+
try:
|
| 74 |
+
logger.debug("Fetching API information")
|
| 75 |
+
|
| 76 |
+
# Get model information
|
| 77 |
+
model_info_dict = model_service.get_model_info()
|
| 78 |
+
model_info = ModelInfo(**model_info_dict)
|
| 79 |
+
|
| 80 |
+
# Define available endpoints
|
| 81 |
+
endpoints = [
|
| 82 |
+
f"{settings.API_V1_STR}/detect",
|
| 83 |
+
f"{settings.API_V1_STR}/species",
|
| 84 |
+
f"{settings.API_V1_STR}/species/{{class_id}}",
|
| 85 |
+
f"{settings.API_V1_STR}/health",
|
| 86 |
+
f"{settings.API_V1_STR}/info"
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
return APIInfo(
|
| 90 |
+
name=settings.PROJECT_NAME,
|
| 91 |
+
version=settings.VERSION,
|
| 92 |
+
description=settings.DESCRIPTION,
|
| 93 |
+
model_info=model_info,
|
| 94 |
+
endpoints=endpoints
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.error(f"Failed to get API info: {str(e)}")
|
| 99 |
+
raise HTTPException(
|
| 100 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 101 |
+
detail=f"Failed to get API information: {str(e)}"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@router.get(
|
| 106 |
+
"/ready",
|
| 107 |
+
summary="Readiness Check",
|
| 108 |
+
description="Check if the API is ready to serve requests"
|
| 109 |
+
)
|
| 110 |
+
async def readiness_check() -> dict:
|
| 111 |
+
"""
|
| 112 |
+
Check if the API is ready to serve requests.
|
| 113 |
+
|
| 114 |
+
This endpoint is specifically designed for Kubernetes readiness probes.
|
| 115 |
+
It returns 200 OK only when the model is loaded and ready to process requests.
|
| 116 |
+
"""
|
| 117 |
+
try:
|
| 118 |
+
health_status = await model_service.health_check()
|
| 119 |
+
|
| 120 |
+
if health_status.get("model_loaded", False):
|
| 121 |
+
return {"status": "ready"}
|
| 122 |
+
else:
|
| 123 |
+
raise HTTPException(
|
| 124 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 125 |
+
detail="Model not ready"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
except HTTPException:
|
| 129 |
+
raise
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"Readiness check failed: {str(e)}")
|
| 132 |
+
raise HTTPException(
|
| 133 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 134 |
+
detail="Service not ready"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@router.get(
|
| 139 |
+
"/live",
|
| 140 |
+
summary="Liveness Check",
|
| 141 |
+
description="Check if the API is alive"
|
| 142 |
+
)
|
| 143 |
+
async def liveness_check() -> dict:
|
| 144 |
+
"""
|
| 145 |
+
Check if the API is alive.
|
| 146 |
+
|
| 147 |
+
This endpoint is designed for Kubernetes liveness probes.
|
| 148 |
+
It performs a minimal check to ensure the API process is running.
|
| 149 |
+
"""
|
| 150 |
+
return {"status": "alive", "timestamp": datetime.utcnow().isoformat()}
|
app/api/v1/endpoints/inference.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Marine species detection inference endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException, status, Depends
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from app.models.inference import (
|
| 9 |
+
InferenceRequest,
|
| 10 |
+
InferenceResponse,
|
| 11 |
+
SpeciesListResponse,
|
| 12 |
+
SpeciesInfo,
|
| 13 |
+
ErrorResponse
|
| 14 |
+
)
|
| 15 |
+
from app.services.inference_service import inference_service
|
| 16 |
+
from app.api.dependencies import validate_model_health
|
| 17 |
+
from app.core.logging import get_logger
|
| 18 |
+
|
| 19 |
+
logger = get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
router = APIRouter()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@router.post(
|
| 25 |
+
"/detect",
|
| 26 |
+
response_model=InferenceResponse,
|
| 27 |
+
responses={
|
| 28 |
+
400: {"model": ErrorResponse, "description": "Bad Request"},
|
| 29 |
+
503: {"model": ErrorResponse, "description": "Service Unavailable"},
|
| 30 |
+
},
|
| 31 |
+
summary="Detect Marine Species",
|
| 32 |
+
description="Detect and identify marine species in an uploaded image using YOLOv5 model"
|
| 33 |
+
)
|
| 34 |
+
async def detect_marine_species(
|
| 35 |
+
request: InferenceRequest,
|
| 36 |
+
_: bool = Depends(validate_model_health)
|
| 37 |
+
) -> InferenceResponse:
|
| 38 |
+
"""
|
| 39 |
+
Detect marine species in an image.
|
| 40 |
+
|
| 41 |
+
- **image**: Base64 encoded image data
|
| 42 |
+
- **confidence_threshold**: Minimum confidence for detections (0.0-1.0)
|
| 43 |
+
- **iou_threshold**: IoU threshold for non-maximum suppression (0.0-1.0)
|
| 44 |
+
- **image_size**: Input image size for inference (320-1280)
|
| 45 |
+
- **return_annotated_image**: Whether to return annotated image with bounding boxes
|
| 46 |
+
- **classes**: Optional list of class IDs to filter detections
|
| 47 |
+
|
| 48 |
+
Returns detection results with bounding boxes, confidence scores, and species names.
|
| 49 |
+
"""
|
| 50 |
+
try:
|
| 51 |
+
logger.info("Processing marine species detection request")
|
| 52 |
+
|
| 53 |
+
result = await inference_service.detect_species(
|
| 54 |
+
image_data=request.image,
|
| 55 |
+
confidence_threshold=request.confidence_threshold,
|
| 56 |
+
iou_threshold=request.iou_threshold,
|
| 57 |
+
image_size=request.image_size,
|
| 58 |
+
return_annotated_image=request.return_annotated_image,
|
| 59 |
+
classes=request.classes
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
logger.info(f"Detection completed: {len(result.detections)} species found")
|
| 63 |
+
return result
|
| 64 |
+
|
| 65 |
+
except ValueError as e:
|
| 66 |
+
logger.error(f"Invalid input data: {str(e)}")
|
| 67 |
+
raise HTTPException(
|
| 68 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 69 |
+
detail=f"Invalid input: {str(e)}"
|
| 70 |
+
)
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.error(f"Detection failed: {str(e)}")
|
| 73 |
+
raise HTTPException(
|
| 74 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 75 |
+
detail=f"Detection failed: {str(e)}"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.get(
|
| 80 |
+
"/species",
|
| 81 |
+
response_model=SpeciesListResponse,
|
| 82 |
+
summary="List Supported Species",
|
| 83 |
+
description="Get a list of all marine species that can be detected by the model"
|
| 84 |
+
)
|
| 85 |
+
async def list_supported_species(
|
| 86 |
+
_: bool = Depends(validate_model_health)
|
| 87 |
+
) -> SpeciesListResponse:
|
| 88 |
+
"""
|
| 89 |
+
Get a list of all supported marine species.
|
| 90 |
+
|
| 91 |
+
Returns a comprehensive list of all marine species that the model can detect,
|
| 92 |
+
including their class IDs and scientific/common names.
|
| 93 |
+
"""
|
| 94 |
+
try:
|
| 95 |
+
logger.info("Fetching supported species list")
|
| 96 |
+
|
| 97 |
+
species_data = await inference_service.get_supported_species()
|
| 98 |
+
|
| 99 |
+
species_list = [
|
| 100 |
+
SpeciesInfo(class_id=item["class_id"], class_name=item["class_name"])
|
| 101 |
+
for item in species_data
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
return SpeciesListResponse(
|
| 105 |
+
species=species_list,
|
| 106 |
+
total_count=len(species_list)
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Failed to fetch species list: {str(e)}")
|
| 111 |
+
raise HTTPException(
|
| 112 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 113 |
+
detail=f"Failed to fetch species list: {str(e)}"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@router.get(
|
| 118 |
+
"/species/{class_id}",
|
| 119 |
+
response_model=SpeciesInfo,
|
| 120 |
+
responses={
|
| 121 |
+
404: {"model": ErrorResponse, "description": "Species Not Found"},
|
| 122 |
+
},
|
| 123 |
+
summary="Get Species Information",
|
| 124 |
+
description="Get information about a specific marine species by class ID"
|
| 125 |
+
)
|
| 126 |
+
async def get_species_info(
|
| 127 |
+
class_id: int,
|
| 128 |
+
_: bool = Depends(validate_model_health)
|
| 129 |
+
) -> SpeciesInfo:
|
| 130 |
+
"""
|
| 131 |
+
Get information about a specific marine species.
|
| 132 |
+
|
| 133 |
+
- **class_id**: The class ID of the species to look up
|
| 134 |
+
|
| 135 |
+
Returns detailed information about the specified marine species.
|
| 136 |
+
"""
|
| 137 |
+
try:
|
| 138 |
+
logger.info(f"Fetching species info for class_id: {class_id}")
|
| 139 |
+
|
| 140 |
+
species_data = await inference_service.get_supported_species()
|
| 141 |
+
|
| 142 |
+
# Find the species with the given class_id
|
| 143 |
+
for species in species_data:
|
| 144 |
+
if species["class_id"] == class_id:
|
| 145 |
+
return SpeciesInfo(
|
| 146 |
+
class_id=species["class_id"],
|
| 147 |
+
class_name=species["class_name"]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Species not found
|
| 151 |
+
raise HTTPException(
|
| 152 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 153 |
+
detail=f"Species with class_id {class_id} not found"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
except HTTPException:
|
| 157 |
+
raise
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.error(f"Failed to fetch species info: {str(e)}")
|
| 160 |
+
raise HTTPException(
|
| 161 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 162 |
+
detail=f"Failed to fetch species info: {str(e)}"
|
| 163 |
+
)
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Core configuration and utilities
|
app/core/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for the FastAPI application.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from pydantic import BaseSettings
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Settings(BaseSettings):
|
| 11 |
+
"""Application settings."""
|
| 12 |
+
|
| 13 |
+
# API Configuration
|
| 14 |
+
API_V1_STR: str = "/api/v1"
|
| 15 |
+
PROJECT_NAME: str = "Marine Species Identification API"
|
| 16 |
+
VERSION: str = "1.0.0"
|
| 17 |
+
DESCRIPTION: str = "FastAPI-based marine species identification using YOLOv5"
|
| 18 |
+
|
| 19 |
+
# Model Configuration
|
| 20 |
+
MODEL_NAME: str = "marina-benthic-33k"
|
| 21 |
+
MODEL_PATH: str = "models/marina-benthic-33k.pt"
|
| 22 |
+
HUGGINGFACE_REPO: str = "seamo-ai/marina-species-v1"
|
| 23 |
+
DEVICE: Optional[str] = None # Auto-detect if None
|
| 24 |
+
|
| 25 |
+
# Inference Configuration
|
| 26 |
+
DEFAULT_CONFIDENCE_THRESHOLD: float = 0.25
|
| 27 |
+
DEFAULT_IOU_THRESHOLD: float = 0.45
|
| 28 |
+
DEFAULT_IMAGE_SIZE: int = 720
|
| 29 |
+
MAX_IMAGE_SIZE: int = 1280
|
| 30 |
+
MIN_IMAGE_SIZE: int = 320
|
| 31 |
+
|
| 32 |
+
# File Upload Configuration
|
| 33 |
+
MAX_FILE_SIZE: int = 10 * 1024 * 1024 # 10MB
|
| 34 |
+
ALLOWED_EXTENSIONS: set = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
|
| 35 |
+
|
| 36 |
+
# Performance Configuration
|
| 37 |
+
MODEL_CACHE_SIZE: int = 1
|
| 38 |
+
ENABLE_MODEL_WARMUP: bool = True
|
| 39 |
+
|
| 40 |
+
# HuggingFace Configuration
|
| 41 |
+
HF_HUB_OFFLINE: bool = os.getenv("HF_HUB_OFFLINE", "0") == "1"
|
| 42 |
+
TRANSFORMERS_NO_ADVISORY_WARNINGS: bool = True
|
| 43 |
+
|
| 44 |
+
# Server Configuration
|
| 45 |
+
HOST: str = "0.0.0.0"
|
| 46 |
+
PORT: int = 7860 # HuggingFace Spaces standard
|
| 47 |
+
|
| 48 |
+
class Config:
|
| 49 |
+
env_file = ".env"
|
| 50 |
+
case_sensitive = True
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Global settings instance
|
| 54 |
+
settings = Settings()
|
app/core/logging.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Logging configuration for the FastAPI application.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def setup_logging(level: str = "INFO") -> None:
|
| 11 |
+
"""
|
| 12 |
+
Setup logging configuration.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
| 16 |
+
"""
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
level=getattr(logging, level.upper()),
|
| 19 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 20 |
+
handlers=[
|
| 21 |
+
logging.StreamHandler(sys.stdout)
|
| 22 |
+
]
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Set specific loggers
|
| 26 |
+
logging.getLogger("uvicorn").setLevel(logging.INFO)
|
| 27 |
+
logging.getLogger("fastapi").setLevel(logging.INFO)
|
| 28 |
+
logging.getLogger("yolov5").setLevel(logging.WARNING) # Reduce YOLOv5 verbosity
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_logger(name: str) -> logging.Logger:
|
| 32 |
+
"""
|
| 33 |
+
Get a logger instance.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
name: Logger name
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Logger instance
|
| 40 |
+
"""
|
| 41 |
+
return logging.getLogger(name)
|
app/main.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Marine Species Identification API
|
| 3 |
+
|
| 4 |
+
Main application entry point for the marine species identification API.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from fastapi.responses import JSONResponse
|
| 12 |
+
import uvicorn
|
| 13 |
+
|
| 14 |
+
from app.core.config import settings
|
| 15 |
+
from app.core.logging import setup_logging, get_logger
|
| 16 |
+
from app.api.v1.api import api_router
|
| 17 |
+
from app.services.model_service import model_service
|
| 18 |
+
|
| 19 |
+
# Setup logging
|
| 20 |
+
setup_logging()
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@asynccontextmanager
|
| 25 |
+
async def lifespan(app: FastAPI):
|
| 26 |
+
"""
|
| 27 |
+
Application lifespan manager.
|
| 28 |
+
Handles startup and shutdown events.
|
| 29 |
+
"""
|
| 30 |
+
# Startup
|
| 31 |
+
logger.info("Starting Marine Species Identification API...")
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# Ensure model is available and loaded
|
| 35 |
+
await model_service.ensure_model_available()
|
| 36 |
+
logger.info("Model loaded successfully")
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Failed to load model during startup: {str(e)}")
|
| 39 |
+
# Don't fail startup - let health checks handle this
|
| 40 |
+
|
| 41 |
+
logger.info("API startup completed")
|
| 42 |
+
|
| 43 |
+
yield
|
| 44 |
+
|
| 45 |
+
# Shutdown
|
| 46 |
+
logger.info("Shutting down Marine Species Identification API...")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Create FastAPI application
|
| 50 |
+
app = FastAPI(
|
| 51 |
+
title=settings.PROJECT_NAME,
|
| 52 |
+
version=settings.VERSION,
|
| 53 |
+
description=settings.DESCRIPTION,
|
| 54 |
+
openapi_url=f"{settings.API_V1_STR}/openapi.json",
|
| 55 |
+
docs_url="/docs",
|
| 56 |
+
redoc_url="/redoc",
|
| 57 |
+
lifespan=lifespan
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Add CORS middleware
|
| 61 |
+
app.add_middleware(
|
| 62 |
+
CORSMiddleware,
|
| 63 |
+
allow_origins=["*"], # Configure appropriately for production
|
| 64 |
+
allow_credentials=True,
|
| 65 |
+
allow_methods=["*"],
|
| 66 |
+
allow_headers=["*"],
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Global exception handler
|
| 71 |
+
@app.exception_handler(Exception)
|
| 72 |
+
async def global_exception_handler(request: Request, exc: Exception):
|
| 73 |
+
"""Global exception handler for unhandled errors."""
|
| 74 |
+
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
|
| 75 |
+
return JSONResponse(
|
| 76 |
+
status_code=500,
|
| 77 |
+
content={
|
| 78 |
+
"error": "Internal Server Error",
|
| 79 |
+
"message": "An unexpected error occurred",
|
| 80 |
+
"details": str(exc) if settings.DEBUG else None
|
| 81 |
+
}
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Include API router
|
| 86 |
+
app.include_router(api_router, prefix=settings.API_V1_STR)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Root endpoint
|
| 90 |
+
@app.get("/", tags=["Root"])
|
| 91 |
+
async def root():
|
| 92 |
+
"""
|
| 93 |
+
Root endpoint providing basic API information.
|
| 94 |
+
"""
|
| 95 |
+
return {
|
| 96 |
+
"message": "Marine Species Identification API",
|
| 97 |
+
"version": settings.VERSION,
|
| 98 |
+
"docs": "/docs",
|
| 99 |
+
"health": f"{settings.API_V1_STR}/health",
|
| 100 |
+
"api_info": f"{settings.API_V1_STR}/info"
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Health check endpoint at root level (for load balancers)
|
| 105 |
+
@app.get("/health", tags=["Health"])
|
| 106 |
+
async def root_health():
|
| 107 |
+
"""Simple health check at root level."""
|
| 108 |
+
return {"status": "ok"}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
# Run the application
|
| 113 |
+
uvicorn.run(
|
| 114 |
+
"app.main:app",
|
| 115 |
+
host=settings.HOST,
|
| 116 |
+
port=settings.PORT,
|
| 117 |
+
reload=False, # Set to True for development
|
| 118 |
+
log_level="info"
|
| 119 |
+
)
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Pydantic models and YOLOv5 wrapper
|
app/models/inference.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for API requests and responses.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional, Dict, Any
|
| 6 |
+
from pydantic import BaseModel, Field, validator
|
| 7 |
+
import base64
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BoundingBox(BaseModel):
|
| 11 |
+
"""Bounding box coordinates."""
|
| 12 |
+
x: float = Field(..., description="X coordinate of top-left corner")
|
| 13 |
+
y: float = Field(..., description="Y coordinate of top-left corner")
|
| 14 |
+
width: float = Field(..., description="Width of bounding box")
|
| 15 |
+
height: float = Field(..., description="Height of bounding box")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Detection(BaseModel):
|
| 19 |
+
"""Single detection result."""
|
| 20 |
+
class_id: int = Field(..., description="Class ID of detected species")
|
| 21 |
+
class_name: str = Field(..., description="Name of detected marine species")
|
| 22 |
+
confidence: float = Field(..., ge=0.0, le=1.0, description="Detection confidence score")
|
| 23 |
+
bbox: BoundingBox = Field(..., description="Bounding box coordinates")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ModelInfo(BaseModel):
|
| 27 |
+
"""Model information."""
|
| 28 |
+
model_name: str = Field(..., description="Name of the model")
|
| 29 |
+
total_classes: int = Field(..., description="Total number of species classes")
|
| 30 |
+
device: str = Field(..., description="Device used for inference")
|
| 31 |
+
model_path: str = Field(..., description="Path to model file")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class InferenceRequest(BaseModel):
|
| 35 |
+
"""Request model for marine species detection."""
|
| 36 |
+
image: str = Field(..., description="Base64 encoded image data")
|
| 37 |
+
confidence_threshold: float = Field(
|
| 38 |
+
default=0.25,
|
| 39 |
+
ge=0.0,
|
| 40 |
+
le=1.0,
|
| 41 |
+
description="Confidence threshold for detections"
|
| 42 |
+
)
|
| 43 |
+
iou_threshold: float = Field(
|
| 44 |
+
default=0.45,
|
| 45 |
+
ge=0.0,
|
| 46 |
+
le=1.0,
|
| 47 |
+
description="IoU threshold for non-maximum suppression"
|
| 48 |
+
)
|
| 49 |
+
image_size: int = Field(
|
| 50 |
+
default=720,
|
| 51 |
+
ge=320,
|
| 52 |
+
le=1280,
|
| 53 |
+
description="Input image size for inference"
|
| 54 |
+
)
|
| 55 |
+
return_annotated_image: bool = Field(
|
| 56 |
+
default=True,
|
| 57 |
+
description="Whether to return annotated image with detections"
|
| 58 |
+
)
|
| 59 |
+
classes: Optional[List[int]] = Field(
|
| 60 |
+
default=None,
|
| 61 |
+
description="List of class IDs to filter (None for all classes)"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
@validator('image')
|
| 65 |
+
def validate_image(cls, v):
|
| 66 |
+
"""Validate base64 image data."""
|
| 67 |
+
try:
|
| 68 |
+
# Try to decode base64 to ensure it's valid
|
| 69 |
+
base64.b64decode(v)
|
| 70 |
+
return v
|
| 71 |
+
except Exception:
|
| 72 |
+
raise ValueError("Invalid base64 image data")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class InferenceResponse(BaseModel):
|
| 76 |
+
"""Response model for marine species detection."""
|
| 77 |
+
detections: List[Detection] = Field(..., description="List of detected marine species")
|
| 78 |
+
annotated_image: Optional[str] = Field(
|
| 79 |
+
default=None,
|
| 80 |
+
description="Base64 encoded annotated image (if requested)"
|
| 81 |
+
)
|
| 82 |
+
processing_time: float = Field(..., description="Processing time in seconds")
|
| 83 |
+
model_info: ModelInfo = Field(..., description="Information about the model used")
|
| 84 |
+
image_dimensions: Dict[str, int] = Field(
|
| 85 |
+
...,
|
| 86 |
+
description="Original image dimensions (width, height)"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class SpeciesInfo(BaseModel):
|
| 91 |
+
"""Information about a marine species."""
|
| 92 |
+
class_id: int = Field(..., description="Class ID")
|
| 93 |
+
class_name: str = Field(..., description="Species name")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SpeciesListResponse(BaseModel):
|
| 97 |
+
"""Response model for species list endpoint."""
|
| 98 |
+
species: List[SpeciesInfo] = Field(..., description="List of all supported marine species")
|
| 99 |
+
total_count: int = Field(..., description="Total number of species")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class HealthResponse(BaseModel):
|
| 103 |
+
"""Response model for health check."""
|
| 104 |
+
status: str = Field(..., description="API status")
|
| 105 |
+
model_loaded: bool = Field(..., description="Whether the model is loaded")
|
| 106 |
+
model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
|
| 107 |
+
timestamp: str = Field(..., description="Response timestamp")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class ErrorResponse(BaseModel):
|
| 111 |
+
"""Error response model."""
|
| 112 |
+
error: str = Field(..., description="Error type")
|
| 113 |
+
message: str = Field(..., description="Error message")
|
| 114 |
+
details: Optional[Dict[str, Any]] = Field(default=None, description="Additional error details")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class APIInfo(BaseModel):
|
| 118 |
+
"""API information response."""
|
| 119 |
+
name: str = Field(..., description="API name")
|
| 120 |
+
version: str = Field(..., description="API version")
|
| 121 |
+
description: str = Field(..., description="API description")
|
| 122 |
+
model_info: ModelInfo = Field(..., description="Model information")
|
| 123 |
+
endpoints: List[str] = Field(..., description="Available endpoints")
|
app/models/yolo.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
YOLOv5 model wrapper adapted from the original Gradio implementation.
|
| 3 |
+
Compatible with the existing marina-benthic-33k.pt model.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import yolov5
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Optional, List, Union, Dict, Any
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from app.core.config import settings
|
| 13 |
+
from app.core.logging import get_logger
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MarineSpeciesYOLO:
|
| 19 |
+
"""
|
| 20 |
+
Wrapper class for loading and running the marine species YOLOv5 model.
|
| 21 |
+
Adapted from the original inference.py to work with FastAPI.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, model_path: str, device: Optional[str] = None):
|
| 25 |
+
"""
|
| 26 |
+
Initialize the YOLO model.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
model_path: Path to the YOLOv5 model file
|
| 30 |
+
device: Device to run inference on ('cpu', 'cuda', etc.)
|
| 31 |
+
"""
|
| 32 |
+
self.model_path = model_path
|
| 33 |
+
self.device = device or self._get_device()
|
| 34 |
+
self.model = None
|
| 35 |
+
self._class_names = None
|
| 36 |
+
|
| 37 |
+
logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
|
| 38 |
+
self._load_model()
|
| 39 |
+
|
| 40 |
+
def _get_device(self) -> str:
|
| 41 |
+
"""Auto-detect the best available device."""
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
return "cuda"
|
| 44 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 45 |
+
return "mps" # Apple Silicon
|
| 46 |
+
else:
|
| 47 |
+
return "cpu"
|
| 48 |
+
|
| 49 |
+
def _load_model(self) -> None:
|
| 50 |
+
"""Load the YOLOv5 model."""
|
| 51 |
+
try:
|
| 52 |
+
if not Path(self.model_path).exists():
|
| 53 |
+
raise FileNotFoundError(f"Model file not found: {self.model_path}")
|
| 54 |
+
|
| 55 |
+
logger.info(f"Loading YOLOv5 model from: {self.model_path}")
|
| 56 |
+
self.model = yolov5.load(self.model_path, device=self.device)
|
| 57 |
+
|
| 58 |
+
# Get class names if available
|
| 59 |
+
if hasattr(self.model, 'names'):
|
| 60 |
+
self._class_names = self.model.names
|
| 61 |
+
logger.info(f"Loaded model with {len(self._class_names)} classes")
|
| 62 |
+
|
| 63 |
+
logger.info("YOLOv5 model loaded successfully")
|
| 64 |
+
|
| 65 |
+
except Exception as e:
|
| 66 |
+
logger.error(f"Failed to load YOLOv5 model: {str(e)}")
|
| 67 |
+
raise
|
| 68 |
+
|
| 69 |
+
def predict(
|
| 70 |
+
self,
|
| 71 |
+
image: Union[str, np.ndarray],
|
| 72 |
+
conf_threshold: float = 0.25,
|
| 73 |
+
iou_threshold: float = 0.45,
|
| 74 |
+
image_size: int = 720,
|
| 75 |
+
classes: Optional[List[int]] = None
|
| 76 |
+
) -> torch.Tensor:
|
| 77 |
+
"""
|
| 78 |
+
Run inference on an image.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
image: Input image (file path or numpy array)
|
| 82 |
+
conf_threshold: Confidence threshold for detections
|
| 83 |
+
iou_threshold: IoU threshold for NMS
|
| 84 |
+
image_size: Input image size for inference
|
| 85 |
+
classes: List of class IDs to filter (None for all classes)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
YOLOv5 detection results
|
| 89 |
+
"""
|
| 90 |
+
if self.model is None:
|
| 91 |
+
raise RuntimeError("Model not loaded")
|
| 92 |
+
|
| 93 |
+
# Set model parameters
|
| 94 |
+
self.model.conf = conf_threshold
|
| 95 |
+
self.model.iou = iou_threshold
|
| 96 |
+
|
| 97 |
+
if classes is not None:
|
| 98 |
+
self.model.classes = classes
|
| 99 |
+
|
| 100 |
+
# Run inference
|
| 101 |
+
try:
|
| 102 |
+
detections = self.model(image, size=image_size)
|
| 103 |
+
return detections
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error(f"Inference failed: {str(e)}")
|
| 106 |
+
raise
|
| 107 |
+
|
| 108 |
+
def get_class_names(self) -> Optional[Dict[int, str]]:
|
| 109 |
+
"""Get the class names mapping."""
|
| 110 |
+
return self._class_names
|
| 111 |
+
|
| 112 |
+
def get_model_info(self) -> Dict[str, Any]:
|
| 113 |
+
"""Get model information."""
|
| 114 |
+
return {
|
| 115 |
+
"model_path": self.model_path,
|
| 116 |
+
"device": self.device,
|
| 117 |
+
"num_classes": len(self._class_names) if self._class_names else None,
|
| 118 |
+
"class_names": self._class_names
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def warmup(self, image_size: int = 720) -> None:
|
| 122 |
+
"""
|
| 123 |
+
Warm up the model with a dummy inference.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
image_size: Size for warmup inference
|
| 127 |
+
"""
|
| 128 |
+
if self.model is None:
|
| 129 |
+
return
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
logger.info("Warming up model...")
|
| 133 |
+
# Create a dummy image
|
| 134 |
+
dummy_image = np.random.randint(0, 255, (image_size, image_size, 3), dtype=np.uint8)
|
| 135 |
+
self.predict(dummy_image, conf_threshold=0.1)
|
| 136 |
+
logger.info("Model warmup completed")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.warning(f"Model warmup failed: {str(e)}")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Global model instance (singleton pattern)
|
| 142 |
+
_model_instance: Optional[MarineSpeciesYOLO] = None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_model() -> MarineSpeciesYOLO:
|
| 146 |
+
"""
|
| 147 |
+
Get the global model instance (singleton pattern).
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
MarineSpeciesYOLO instance
|
| 151 |
+
"""
|
| 152 |
+
global _model_instance
|
| 153 |
+
|
| 154 |
+
if _model_instance is None:
|
| 155 |
+
_model_instance = MarineSpeciesYOLO(
|
| 156 |
+
model_path=settings.MODEL_PATH,
|
| 157 |
+
device=settings.DEVICE
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Warm up the model if enabled
|
| 161 |
+
if settings.ENABLE_MODEL_WARMUP:
|
| 162 |
+
_model_instance.warmup()
|
| 163 |
+
|
| 164 |
+
return _model_instance
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_class_names(names_file: str) -> Dict[int, str]:
|
| 168 |
+
"""
|
| 169 |
+
Load class names from a .names file.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
names_file: Path to the .names file
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
Dictionary mapping class IDs to names
|
| 176 |
+
"""
|
| 177 |
+
class_names = {}
|
| 178 |
+
try:
|
| 179 |
+
with open(names_file, 'r') as f:
|
| 180 |
+
for idx, line in enumerate(f):
|
| 181 |
+
class_names[idx] = line.strip()
|
| 182 |
+
logger.info(f"Loaded {len(class_names)} class names from {names_file}")
|
| 183 |
+
except Exception as e:
|
| 184 |
+
logger.error(f"Failed to load class names: {str(e)}")
|
| 185 |
+
|
| 186 |
+
return class_names
|
app/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Services layer for business logic
|
app/services/inference_service.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference service for marine species detection.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
from typing import List, Optional, Dict, Tuple
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
from app.core.config import settings
|
| 14 |
+
from app.core.logging import get_logger
|
| 15 |
+
from app.models.inference import Detection, BoundingBox, InferenceResponse, ModelInfo
|
| 16 |
+
from app.services.model_service import model_service
|
| 17 |
+
from app.utils.image_processing import decode_base64_image, encode_image_to_base64
|
| 18 |
+
|
| 19 |
+
logger = get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class InferenceService:
|
| 23 |
+
"""Service for running marine species detection inference."""
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.model_service = model_service
|
| 27 |
+
|
| 28 |
+
async def detect_species(
|
| 29 |
+
self,
|
| 30 |
+
image_data: str,
|
| 31 |
+
confidence_threshold: float = 0.25,
|
| 32 |
+
iou_threshold: float = 0.45,
|
| 33 |
+
image_size: int = 720,
|
| 34 |
+
return_annotated_image: bool = True,
|
| 35 |
+
classes: Optional[List[int]] = None
|
| 36 |
+
) -> InferenceResponse:
|
| 37 |
+
"""
|
| 38 |
+
Detect marine species in an image.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
image_data: Base64 encoded image
|
| 42 |
+
confidence_threshold: Confidence threshold for detections
|
| 43 |
+
iou_threshold: IoU threshold for NMS
|
| 44 |
+
image_size: Input image size for inference
|
| 45 |
+
return_annotated_image: Whether to return annotated image
|
| 46 |
+
classes: List of class IDs to filter
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
InferenceResponse with detection results
|
| 50 |
+
"""
|
| 51 |
+
start_time = time.time()
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
# Decode the image
|
| 55 |
+
image, original_dims = decode_base64_image(image_data)
|
| 56 |
+
logger.info(f"Processing image with dimensions: {original_dims}")
|
| 57 |
+
|
| 58 |
+
# Get the model
|
| 59 |
+
model = self.model_service.get_model()
|
| 60 |
+
|
| 61 |
+
# Run inference
|
| 62 |
+
predictions = model.predict(
|
| 63 |
+
image=image,
|
| 64 |
+
conf_threshold=confidence_threshold,
|
| 65 |
+
iou_threshold=iou_threshold,
|
| 66 |
+
image_size=image_size,
|
| 67 |
+
classes=classes
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Process predictions
|
| 71 |
+
detections = self._process_predictions(predictions)
|
| 72 |
+
|
| 73 |
+
# Generate annotated image if requested
|
| 74 |
+
annotated_image_b64 = None
|
| 75 |
+
if return_annotated_image and detections:
|
| 76 |
+
annotated_image = self._create_annotated_image(image, predictions)
|
| 77 |
+
annotated_image_b64 = encode_image_to_base64(annotated_image)
|
| 78 |
+
|
| 79 |
+
# Get model info
|
| 80 |
+
model_info_dict = self.model_service.get_model_info()
|
| 81 |
+
model_info = ModelInfo(**model_info_dict)
|
| 82 |
+
|
| 83 |
+
processing_time = time.time() - start_time
|
| 84 |
+
|
| 85 |
+
logger.info(f"Inference completed in {processing_time:.3f}s, found {len(detections)} detections")
|
| 86 |
+
|
| 87 |
+
return InferenceResponse(
|
| 88 |
+
detections=detections,
|
| 89 |
+
annotated_image=annotated_image_b64,
|
| 90 |
+
processing_time=processing_time,
|
| 91 |
+
model_info=model_info,
|
| 92 |
+
image_dimensions={"width": original_dims[0], "height": original_dims[1]}
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
except Exception as e:
|
| 96 |
+
logger.error(f"Inference failed: {str(e)}")
|
| 97 |
+
raise
|
| 98 |
+
|
| 99 |
+
def _process_predictions(self, predictions) -> List[Detection]:
|
| 100 |
+
"""
|
| 101 |
+
Process YOLOv5 predictions into Detection objects.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
predictions: YOLOv5 prediction results
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
List of Detection objects
|
| 108 |
+
"""
|
| 109 |
+
detections = []
|
| 110 |
+
class_names = self.model_service.get_class_names()
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
# Get predictions as pandas DataFrame
|
| 114 |
+
pred_df = predictions.pandas().xyxy[0]
|
| 115 |
+
|
| 116 |
+
for _, row in pred_df.iterrows():
|
| 117 |
+
# Extract bounding box coordinates
|
| 118 |
+
x1, y1, x2, y2 = row['xmin'], row['ymin'], row['xmax'], row['ymax']
|
| 119 |
+
width = x2 - x1
|
| 120 |
+
height = y2 - y1
|
| 121 |
+
|
| 122 |
+
# Get class information
|
| 123 |
+
class_id = int(row['class'])
|
| 124 |
+
confidence = float(row['confidence'])
|
| 125 |
+
|
| 126 |
+
# Get class name
|
| 127 |
+
if class_names and class_id in class_names:
|
| 128 |
+
class_name = class_names[class_id]
|
| 129 |
+
else:
|
| 130 |
+
class_name = f"class_{class_id}"
|
| 131 |
+
|
| 132 |
+
# Create detection object
|
| 133 |
+
detection = Detection(
|
| 134 |
+
class_id=class_id,
|
| 135 |
+
class_name=class_name,
|
| 136 |
+
confidence=confidence,
|
| 137 |
+
bbox=BoundingBox(
|
| 138 |
+
x=float(x1),
|
| 139 |
+
y=float(y1),
|
| 140 |
+
width=float(width),
|
| 141 |
+
height=float(height)
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
detections.append(detection)
|
| 146 |
+
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logger.error(f"Failed to process predictions: {str(e)}")
|
| 149 |
+
raise
|
| 150 |
+
|
| 151 |
+
return detections
|
| 152 |
+
|
| 153 |
+
def _create_annotated_image(self, original_image: np.ndarray, predictions) -> np.ndarray:
|
| 154 |
+
"""
|
| 155 |
+
Create an annotated image with detection boxes and labels.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
original_image: Original input image
|
| 159 |
+
predictions: YOLOv5 prediction results
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Annotated image as numpy array
|
| 163 |
+
"""
|
| 164 |
+
try:
|
| 165 |
+
# Use YOLOv5's built-in rendering
|
| 166 |
+
rendered_imgs = predictions.render()
|
| 167 |
+
if rendered_imgs and len(rendered_imgs) > 0:
|
| 168 |
+
return rendered_imgs[0]
|
| 169 |
+
else:
|
| 170 |
+
# Fallback: return original image if rendering fails
|
| 171 |
+
return original_image
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Failed to create annotated image: {str(e)}")
|
| 175 |
+
# Return original image as fallback
|
| 176 |
+
return original_image
|
| 177 |
+
|
| 178 |
+
async def get_supported_species(self) -> List[Dict]:
|
| 179 |
+
"""
|
| 180 |
+
Get list of all supported marine species.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
List of species information
|
| 184 |
+
"""
|
| 185 |
+
class_names = self.model_service.get_class_names()
|
| 186 |
+
|
| 187 |
+
if not class_names:
|
| 188 |
+
return []
|
| 189 |
+
|
| 190 |
+
species_list = []
|
| 191 |
+
for class_id, class_name in class_names.items():
|
| 192 |
+
species_list.append({
|
| 193 |
+
"class_id": class_id,
|
| 194 |
+
"class_name": class_name
|
| 195 |
+
})
|
| 196 |
+
|
| 197 |
+
return sorted(species_list, key=lambda x: x["class_name"])
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# Global service instance
|
| 201 |
+
inference_service = InferenceService()
|
app/services/model_service.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model service for managing YOLOv5 model lifecycle and operations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
|
| 10 |
+
from app.core.config import settings
|
| 11 |
+
from app.core.logging import get_logger
|
| 12 |
+
from app.models.yolo import MarineSpeciesYOLO, get_model
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ModelService:
|
| 18 |
+
"""Service for managing the marine species detection model."""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self._model: Optional[MarineSpeciesYOLO] = None
|
| 22 |
+
self._class_names: Optional[Dict[int, str]] = None
|
| 23 |
+
|
| 24 |
+
async def ensure_model_available(self) -> None:
|
| 25 |
+
"""
|
| 26 |
+
Ensure the model is downloaded and available.
|
| 27 |
+
Downloads from HuggingFace Hub if not present locally.
|
| 28 |
+
"""
|
| 29 |
+
model_path = Path(settings.MODEL_PATH)
|
| 30 |
+
|
| 31 |
+
# Check if model exists locally
|
| 32 |
+
if not model_path.exists():
|
| 33 |
+
logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
|
| 34 |
+
await self._download_model()
|
| 35 |
+
|
| 36 |
+
# Load class names if available
|
| 37 |
+
await self._load_class_names()
|
| 38 |
+
|
| 39 |
+
async def _download_model(self) -> None:
|
| 40 |
+
"""Download model from HuggingFace Hub."""
|
| 41 |
+
try:
|
| 42 |
+
# Create models directory if it doesn't exist
|
| 43 |
+
model_dir = Path(settings.MODEL_PATH).parent
|
| 44 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
# Download the model file
|
| 47 |
+
logger.info(f"Downloading model from {settings.HUGGINGFACE_REPO}")
|
| 48 |
+
|
| 49 |
+
# Download the .pt model file
|
| 50 |
+
model_filename = f"{settings.MODEL_NAME}.pt"
|
| 51 |
+
downloaded_path = hf_hub_download(
|
| 52 |
+
repo_id=settings.HUGGINGFACE_REPO,
|
| 53 |
+
filename=model_filename,
|
| 54 |
+
cache_dir=str(model_dir.parent / ".cache"),
|
| 55 |
+
local_dir=str(model_dir),
|
| 56 |
+
local_dir_use_symlinks=False
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
logger.info(f"Model downloaded successfully to: {downloaded_path}")
|
| 60 |
+
|
| 61 |
+
# Also download the .names file if available
|
| 62 |
+
try:
|
| 63 |
+
names_filename = f"{settings.MODEL_NAME}.names"
|
| 64 |
+
names_path = hf_hub_download(
|
| 65 |
+
repo_id=settings.HUGGINGFACE_REPO,
|
| 66 |
+
filename=names_filename,
|
| 67 |
+
cache_dir=str(model_dir.parent / ".cache"),
|
| 68 |
+
local_dir=str(model_dir),
|
| 69 |
+
local_dir_use_symlinks=False
|
| 70 |
+
)
|
| 71 |
+
logger.info(f"Class names file downloaded to: {names_path}")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
logger.warning(f"Could not download .names file: {str(e)}")
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.error(f"Failed to download model: {str(e)}")
|
| 77 |
+
raise RuntimeError(f"Model download failed: {str(e)}")
|
| 78 |
+
|
| 79 |
+
async def _load_class_names(self) -> None:
|
| 80 |
+
"""Load class names from .names file."""
|
| 81 |
+
names_file = Path(settings.MODEL_PATH).with_suffix('.names')
|
| 82 |
+
|
| 83 |
+
if names_file.exists():
|
| 84 |
+
try:
|
| 85 |
+
class_names = {}
|
| 86 |
+
with open(names_file, 'r') as f:
|
| 87 |
+
for idx, line in enumerate(f):
|
| 88 |
+
class_names[idx] = line.strip()
|
| 89 |
+
|
| 90 |
+
self._class_names = class_names
|
| 91 |
+
logger.info(f"Loaded {len(class_names)} class names")
|
| 92 |
+
except Exception as e:
|
| 93 |
+
logger.error(f"Failed to load class names: {str(e)}")
|
| 94 |
+
else:
|
| 95 |
+
logger.warning(f"Class names file not found: {names_file}")
|
| 96 |
+
|
| 97 |
+
def get_model(self) -> MarineSpeciesYOLO:
|
| 98 |
+
"""
|
| 99 |
+
Get the model instance.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
MarineSpeciesYOLO instance
|
| 103 |
+
"""
|
| 104 |
+
if self._model is None:
|
| 105 |
+
self._model = get_model()
|
| 106 |
+
return self._model
|
| 107 |
+
|
| 108 |
+
def get_class_names(self) -> Optional[Dict[int, str]]:
|
| 109 |
+
"""
|
| 110 |
+
Get class names mapping.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Dictionary mapping class IDs to names
|
| 114 |
+
"""
|
| 115 |
+
if self._class_names is None:
|
| 116 |
+
# Try to get from model
|
| 117 |
+
model = self.get_model()
|
| 118 |
+
self._class_names = model.get_class_names()
|
| 119 |
+
|
| 120 |
+
return self._class_names
|
| 121 |
+
|
| 122 |
+
def get_model_info(self) -> Dict:
|
| 123 |
+
"""
|
| 124 |
+
Get comprehensive model information.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Dictionary with model information
|
| 128 |
+
"""
|
| 129 |
+
model = self.get_model()
|
| 130 |
+
class_names = self.get_class_names()
|
| 131 |
+
|
| 132 |
+
return {
|
| 133 |
+
"model_name": settings.MODEL_NAME,
|
| 134 |
+
"total_classes": len(class_names) if class_names else 0,
|
| 135 |
+
"device": model.device,
|
| 136 |
+
"model_path": settings.MODEL_PATH,
|
| 137 |
+
"huggingface_repo": settings.HUGGINGFACE_REPO
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
async def health_check(self) -> Dict:
|
| 141 |
+
"""
|
| 142 |
+
Perform a health check on the model.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Dictionary with health status
|
| 146 |
+
"""
|
| 147 |
+
try:
|
| 148 |
+
model = self.get_model()
|
| 149 |
+
model_info = self.get_model_info()
|
| 150 |
+
|
| 151 |
+
return {
|
| 152 |
+
"status": "healthy",
|
| 153 |
+
"model_loaded": True,
|
| 154 |
+
"model_info": model_info
|
| 155 |
+
}
|
| 156 |
+
except Exception as e:
|
| 157 |
+
logger.error(f"Model health check failed: {str(e)}")
|
| 158 |
+
return {
|
| 159 |
+
"status": "unhealthy",
|
| 160 |
+
"model_loaded": False,
|
| 161 |
+
"error": str(e)
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# Global service instance
|
| 166 |
+
model_service = ModelService()
|
app/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Utility functions
|
app/utils/image_processing.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image processing utilities for the FastAPI application.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import io
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
from app.core.config import settings
|
| 13 |
+
from app.core.logging import get_logger
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def decode_base64_image(image_data: str) -> Tuple[np.ndarray, Tuple[int, int]]:
|
| 19 |
+
"""
|
| 20 |
+
Decode base64 image data to numpy array.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
image_data: Base64 encoded image string
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Tuple of (image_array, (width, height))
|
| 27 |
+
"""
|
| 28 |
+
try:
|
| 29 |
+
# Remove data URL prefix if present
|
| 30 |
+
if image_data.startswith('data:image'):
|
| 31 |
+
image_data = image_data.split(',')[1]
|
| 32 |
+
|
| 33 |
+
# Decode base64
|
| 34 |
+
image_bytes = base64.b64decode(image_data)
|
| 35 |
+
|
| 36 |
+
# Open with PIL
|
| 37 |
+
pil_image = Image.open(io.BytesIO(image_bytes))
|
| 38 |
+
|
| 39 |
+
# Convert to RGB if necessary
|
| 40 |
+
if pil_image.mode != 'RGB':
|
| 41 |
+
pil_image = pil_image.convert('RGB')
|
| 42 |
+
|
| 43 |
+
# Get original dimensions
|
| 44 |
+
original_dims = pil_image.size # (width, height)
|
| 45 |
+
|
| 46 |
+
# Convert to numpy array
|
| 47 |
+
image_array = np.array(pil_image)
|
| 48 |
+
|
| 49 |
+
logger.debug(f"Decoded image with shape: {image_array.shape}")
|
| 50 |
+
|
| 51 |
+
return image_array, original_dims
|
| 52 |
+
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"Failed to decode base64 image: {str(e)}")
|
| 55 |
+
raise ValueError(f"Invalid image data: {str(e)}")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def encode_image_to_base64(image: np.ndarray, format: str = "JPEG", quality: int = 95) -> str:
|
| 59 |
+
"""
|
| 60 |
+
Encode numpy array image to base64 string.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
image: Image as numpy array
|
| 64 |
+
format: Image format (JPEG, PNG, etc.)
|
| 65 |
+
quality: JPEG quality (1-100)
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Base64 encoded image string
|
| 69 |
+
"""
|
| 70 |
+
try:
|
| 71 |
+
# Convert numpy array to PIL Image
|
| 72 |
+
if image.dtype != np.uint8:
|
| 73 |
+
image = (image * 255).astype(np.uint8)
|
| 74 |
+
|
| 75 |
+
pil_image = Image.fromarray(image)
|
| 76 |
+
|
| 77 |
+
# Save to bytes buffer
|
| 78 |
+
buffer = io.BytesIO()
|
| 79 |
+
save_kwargs = {"format": format}
|
| 80 |
+
|
| 81 |
+
if format.upper() == "JPEG":
|
| 82 |
+
save_kwargs["quality"] = quality
|
| 83 |
+
save_kwargs["optimize"] = True
|
| 84 |
+
|
| 85 |
+
pil_image.save(buffer, **save_kwargs)
|
| 86 |
+
|
| 87 |
+
# Encode to base64
|
| 88 |
+
image_bytes = buffer.getvalue()
|
| 89 |
+
base64_string = base64.b64encode(image_bytes).decode('utf-8')
|
| 90 |
+
|
| 91 |
+
return base64_string
|
| 92 |
+
|
| 93 |
+
except Exception as e:
|
| 94 |
+
logger.error(f"Failed to encode image to base64: {str(e)}")
|
| 95 |
+
raise ValueError(f"Image encoding failed: {str(e)}")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def validate_image_size(image: np.ndarray) -> bool:
|
| 99 |
+
"""
|
| 100 |
+
Validate image dimensions.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
image: Image as numpy array
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
True if image size is valid
|
| 107 |
+
"""
|
| 108 |
+
height, width = image.shape[:2]
|
| 109 |
+
|
| 110 |
+
# Check minimum and maximum dimensions
|
| 111 |
+
min_dim = min(width, height)
|
| 112 |
+
max_dim = max(width, height)
|
| 113 |
+
|
| 114 |
+
if min_dim < 32: # Too small
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
if max_dim > 4096: # Too large
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
return True
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def resize_image_if_needed(image: np.ndarray, max_size: int = 1280) -> np.ndarray:
|
| 124 |
+
"""
|
| 125 |
+
Resize image if it's too large while maintaining aspect ratio.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
image: Image as numpy array
|
| 129 |
+
max_size: Maximum dimension size
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Resized image
|
| 133 |
+
"""
|
| 134 |
+
height, width = image.shape[:2]
|
| 135 |
+
|
| 136 |
+
if max(height, width) <= max_size:
|
| 137 |
+
return image
|
| 138 |
+
|
| 139 |
+
# Calculate new dimensions
|
| 140 |
+
if width > height:
|
| 141 |
+
new_width = max_size
|
| 142 |
+
new_height = int(height * (max_size / width))
|
| 143 |
+
else:
|
| 144 |
+
new_height = max_size
|
| 145 |
+
new_width = int(width * (max_size / height))
|
| 146 |
+
|
| 147 |
+
# Resize using PIL for better quality
|
| 148 |
+
pil_image = Image.fromarray(image)
|
| 149 |
+
resized_pil = pil_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 150 |
+
|
| 151 |
+
return np.array(resized_pil)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def validate_image_format(image_bytes: bytes) -> bool:
|
| 155 |
+
"""
|
| 156 |
+
Validate if the image format is supported.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
image_bytes: Raw image bytes
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
True if format is supported
|
| 163 |
+
"""
|
| 164 |
+
try:
|
| 165 |
+
with Image.open(io.BytesIO(image_bytes)) as img:
|
| 166 |
+
# Check if format is in allowed extensions
|
| 167 |
+
format_lower = img.format.lower() if img.format else ""
|
| 168 |
+
allowed_formats = {"jpeg", "jpg", "png", "bmp", "tiff", "webp"}
|
| 169 |
+
return format_lower in allowed_formats
|
| 170 |
+
except Exception:
|
| 171 |
+
return False
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_image_info(image: np.ndarray) -> dict:
|
| 175 |
+
"""
|
| 176 |
+
Get information about an image.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
image: Image as numpy array
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Dictionary with image information
|
| 183 |
+
"""
|
| 184 |
+
height, width = image.shape[:2]
|
| 185 |
+
channels = image.shape[2] if len(image.shape) > 2 else 1
|
| 186 |
+
|
| 187 |
+
return {
|
| 188 |
+
"width": width,
|
| 189 |
+
"height": height,
|
| 190 |
+
"channels": channels,
|
| 191 |
+
"dtype": str(image.dtype),
|
| 192 |
+
"size_mb": image.nbytes / (1024 * 1024)
|
| 193 |
+
}
|
app/utils/model_utils.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model utilities for downloading and managing the marine species model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional, Dict, Any
|
| 9 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
| 10 |
+
|
| 11 |
+
from app.core.config import settings
|
| 12 |
+
from app.core.logging import get_logger
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def download_model_from_hf(
|
| 18 |
+
repo_id: str,
|
| 19 |
+
model_filename: str,
|
| 20 |
+
local_dir: str,
|
| 21 |
+
force_download: bool = False
|
| 22 |
+
) -> str:
|
| 23 |
+
"""
|
| 24 |
+
Download model from HuggingFace Hub.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
repo_id: HuggingFace repository ID
|
| 28 |
+
model_filename: Name of the model file
|
| 29 |
+
local_dir: Local directory to save the model
|
| 30 |
+
force_download: Whether to force re-download if file exists
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Path to the downloaded model file
|
| 34 |
+
"""
|
| 35 |
+
try:
|
| 36 |
+
# Create local directory if it doesn't exist
|
| 37 |
+
Path(local_dir).mkdir(parents=True, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
local_path = Path(local_dir) / model_filename
|
| 40 |
+
|
| 41 |
+
# Check if file already exists and force_download is False
|
| 42 |
+
if local_path.exists() and not force_download:
|
| 43 |
+
logger.info(f"Model already exists at {local_path}")
|
| 44 |
+
return str(local_path)
|
| 45 |
+
|
| 46 |
+
logger.info(f"Downloading {model_filename} from {repo_id}...")
|
| 47 |
+
|
| 48 |
+
downloaded_path = hf_hub_download(
|
| 49 |
+
repo_id=repo_id,
|
| 50 |
+
filename=model_filename,
|
| 51 |
+
local_dir=local_dir,
|
| 52 |
+
local_dir_use_symlinks=False,
|
| 53 |
+
force_download=force_download
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
logger.info(f"Model downloaded successfully to: {downloaded_path}")
|
| 57 |
+
return downloaded_path
|
| 58 |
+
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.error(f"Failed to download model: {str(e)}")
|
| 61 |
+
raise
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def list_available_files(repo_id: str) -> list:
|
| 65 |
+
"""
|
| 66 |
+
List all available files in a HuggingFace repository.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
repo_id: HuggingFace repository ID
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
List of available files
|
| 73 |
+
"""
|
| 74 |
+
try:
|
| 75 |
+
files = list_repo_files(repo_id)
|
| 76 |
+
return files
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.error(f"Failed to list repository files: {str(e)}")
|
| 79 |
+
return []
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def verify_model_file(model_path: str) -> bool:
|
| 83 |
+
"""
|
| 84 |
+
Verify that a model file exists and is valid.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
model_path: Path to the model file
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
True if model file is valid
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
path = Path(model_path)
|
| 94 |
+
|
| 95 |
+
# Check if file exists
|
| 96 |
+
if not path.exists():
|
| 97 |
+
logger.error(f"Model file does not exist: {model_path}")
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
# Check file size (should be > 1MB for a real model)
|
| 101 |
+
file_size = path.stat().st_size
|
| 102 |
+
if file_size < 1024 * 1024: # 1MB
|
| 103 |
+
logger.warning(f"Model file seems too small: {file_size} bytes")
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
# Check file extension
|
| 107 |
+
if not path.suffix.lower() in ['.pt', '.pth']:
|
| 108 |
+
logger.warning(f"Unexpected model file extension: {path.suffix}")
|
| 109 |
+
|
| 110 |
+
logger.info(f"Model file verified: {model_path} ({file_size / (1024*1024):.1f} MB)")
|
| 111 |
+
return True
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"Failed to verify model file: {str(e)}")
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_model_info(model_path: str) -> Dict[str, Any]:
|
| 119 |
+
"""
|
| 120 |
+
Get information about a model file.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
model_path: Path to the model file
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Dictionary with model information
|
| 127 |
+
"""
|
| 128 |
+
info = {
|
| 129 |
+
"path": model_path,
|
| 130 |
+
"exists": False,
|
| 131 |
+
"size_mb": 0,
|
| 132 |
+
"size_bytes": 0
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
path = Path(model_path)
|
| 137 |
+
|
| 138 |
+
if path.exists():
|
| 139 |
+
info["exists"] = True
|
| 140 |
+
size_bytes = path.stat().st_size
|
| 141 |
+
info["size_bytes"] = size_bytes
|
| 142 |
+
info["size_mb"] = size_bytes / (1024 * 1024)
|
| 143 |
+
info["modified_time"] = path.stat().st_mtime
|
| 144 |
+
|
| 145 |
+
except Exception as e:
|
| 146 |
+
logger.error(f"Failed to get model info: {str(e)}")
|
| 147 |
+
|
| 148 |
+
return info
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def cleanup_model_cache(cache_dir: Optional[str] = None) -> None:
|
| 152 |
+
"""
|
| 153 |
+
Clean up model cache directory.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
cache_dir: Cache directory to clean (uses default if None)
|
| 157 |
+
"""
|
| 158 |
+
try:
|
| 159 |
+
if cache_dir is None:
|
| 160 |
+
cache_dir = Path.home() / ".cache" / "huggingface"
|
| 161 |
+
|
| 162 |
+
cache_path = Path(cache_dir)
|
| 163 |
+
|
| 164 |
+
if cache_path.exists():
|
| 165 |
+
logger.info(f"Cleaning up cache directory: {cache_path}")
|
| 166 |
+
shutil.rmtree(cache_path)
|
| 167 |
+
logger.info("Cache cleanup completed")
|
| 168 |
+
else:
|
| 169 |
+
logger.info("Cache directory does not exist")
|
| 170 |
+
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.error(f"Failed to cleanup cache: {str(e)}")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def setup_model_directory() -> str:
|
| 176 |
+
"""
|
| 177 |
+
Setup the model directory and ensure it exists.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Path to the model directory
|
| 181 |
+
"""
|
| 182 |
+
model_dir = Path(settings.MODEL_PATH).parent
|
| 183 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 184 |
+
|
| 185 |
+
logger.info(f"Model directory setup: {model_dir}")
|
| 186 |
+
return str(model_dir)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
|
| 190 |
+
# Command line utility for model management
|
| 191 |
+
import argparse
|
| 192 |
+
|
| 193 |
+
parser = argparse.ArgumentParser(description="Model management utility")
|
| 194 |
+
parser.add_argument("--download", action="store_true", help="Download model from HuggingFace")
|
| 195 |
+
parser.add_argument("--verify", action="store_true", help="Verify model file")
|
| 196 |
+
parser.add_argument("--info", action="store_true", help="Show model information")
|
| 197 |
+
parser.add_argument("--list-files", action="store_true", help="List available files in HF repo")
|
| 198 |
+
parser.add_argument("--cleanup-cache", action="store_true", help="Cleanup model cache")
|
| 199 |
+
parser.add_argument("--force", action="store_true", help="Force download even if file exists")
|
| 200 |
+
|
| 201 |
+
args = parser.parse_args()
|
| 202 |
+
|
| 203 |
+
if args.download:
|
| 204 |
+
setup_model_directory()
|
| 205 |
+
download_model_from_hf(
|
| 206 |
+
repo_id=settings.HUGGINGFACE_REPO,
|
| 207 |
+
model_filename=f"{settings.MODEL_NAME}.pt",
|
| 208 |
+
local_dir=str(Path(settings.MODEL_PATH).parent),
|
| 209 |
+
force_download=args.force
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
if args.verify:
|
| 213 |
+
is_valid = verify_model_file(settings.MODEL_PATH)
|
| 214 |
+
print(f"Model valid: {is_valid}")
|
| 215 |
+
|
| 216 |
+
if args.info:
|
| 217 |
+
info = get_model_info(settings.MODEL_PATH)
|
| 218 |
+
print(f"Model info: {info}")
|
| 219 |
+
|
| 220 |
+
if args.list_files:
|
| 221 |
+
files = list_available_files(settings.HUGGINGFACE_REPO)
|
| 222 |
+
print(f"Available files in {settings.HUGGINGFACE_REPO}:")
|
| 223 |
+
for file in files:
|
| 224 |
+
print(f" - {file}")
|
| 225 |
+
|
| 226 |
+
if args.cleanup_cache:
|
| 227 |
+
cleanup_model_cache()
|
app/utils/performance.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Performance monitoring and optimization utilities.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
import psutil
|
| 7 |
+
import functools
|
| 8 |
+
from typing import Dict, Any, Callable, Optional
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
|
| 11 |
+
from app.core.logging import get_logger
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PerformanceMonitor:
|
| 17 |
+
"""Performance monitoring utility."""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.metrics = {
|
| 21 |
+
"requests_total": 0,
|
| 22 |
+
"requests_successful": 0,
|
| 23 |
+
"requests_failed": 0,
|
| 24 |
+
"total_processing_time": 0.0,
|
| 25 |
+
"average_processing_time": 0.0,
|
| 26 |
+
"min_processing_time": float('inf'),
|
| 27 |
+
"max_processing_time": 0.0
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
def record_request(self, processing_time: float, success: bool = True):
|
| 31 |
+
"""Record a request's performance metrics."""
|
| 32 |
+
self.metrics["requests_total"] += 1
|
| 33 |
+
|
| 34 |
+
if success:
|
| 35 |
+
self.metrics["requests_successful"] += 1
|
| 36 |
+
else:
|
| 37 |
+
self.metrics["requests_failed"] += 1
|
| 38 |
+
|
| 39 |
+
self.metrics["total_processing_time"] += processing_time
|
| 40 |
+
self.metrics["average_processing_time"] = (
|
| 41 |
+
self.metrics["total_processing_time"] / self.metrics["requests_total"]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
if processing_time < self.metrics["min_processing_time"]:
|
| 45 |
+
self.metrics["min_processing_time"] = processing_time
|
| 46 |
+
|
| 47 |
+
if processing_time > self.metrics["max_processing_time"]:
|
| 48 |
+
self.metrics["max_processing_time"] = processing_time
|
| 49 |
+
|
| 50 |
+
def get_metrics(self) -> Dict[str, Any]:
|
| 51 |
+
"""Get current performance metrics."""
|
| 52 |
+
metrics = self.metrics.copy()
|
| 53 |
+
|
| 54 |
+
# Add system metrics
|
| 55 |
+
try:
|
| 56 |
+
metrics.update({
|
| 57 |
+
"cpu_percent": psutil.cpu_percent(),
|
| 58 |
+
"memory_percent": psutil.virtual_memory().percent,
|
| 59 |
+
"memory_available_mb": psutil.virtual_memory().available / (1024 * 1024)
|
| 60 |
+
})
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.warning(f"Failed to get system metrics: {str(e)}")
|
| 63 |
+
|
| 64 |
+
return metrics
|
| 65 |
+
|
| 66 |
+
def reset_metrics(self):
|
| 67 |
+
"""Reset all metrics."""
|
| 68 |
+
self.metrics = {
|
| 69 |
+
"requests_total": 0,
|
| 70 |
+
"requests_successful": 0,
|
| 71 |
+
"requests_failed": 0,
|
| 72 |
+
"total_processing_time": 0.0,
|
| 73 |
+
"average_processing_time": 0.0,
|
| 74 |
+
"min_processing_time": float('inf'),
|
| 75 |
+
"max_processing_time": 0.0
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Global performance monitor instance
|
| 80 |
+
performance_monitor = PerformanceMonitor()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@contextmanager
|
| 84 |
+
def measure_time():
|
| 85 |
+
"""Context manager to measure execution time."""
|
| 86 |
+
start_time = time.time()
|
| 87 |
+
try:
|
| 88 |
+
yield
|
| 89 |
+
finally:
|
| 90 |
+
end_time = time.time()
|
| 91 |
+
execution_time = end_time - start_time
|
| 92 |
+
logger.debug(f"Execution time: {execution_time:.3f}s")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def timed_function(func: Callable) -> Callable:
|
| 96 |
+
"""Decorator to measure function execution time."""
|
| 97 |
+
@functools.wraps(func)
|
| 98 |
+
def wrapper(*args, **kwargs):
|
| 99 |
+
start_time = time.time()
|
| 100 |
+
try:
|
| 101 |
+
result = func(*args, **kwargs)
|
| 102 |
+
success = True
|
| 103 |
+
return result
|
| 104 |
+
except Exception as e:
|
| 105 |
+
success = False
|
| 106 |
+
raise
|
| 107 |
+
finally:
|
| 108 |
+
end_time = time.time()
|
| 109 |
+
execution_time = end_time - start_time
|
| 110 |
+
performance_monitor.record_request(execution_time, success)
|
| 111 |
+
logger.debug(f"{func.__name__} execution time: {execution_time:.3f}s")
|
| 112 |
+
|
| 113 |
+
return wrapper
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def timed_async_function(func: Callable) -> Callable:
|
| 117 |
+
"""Decorator to measure async function execution time."""
|
| 118 |
+
@functools.wraps(func)
|
| 119 |
+
async def wrapper(*args, **kwargs):
|
| 120 |
+
start_time = time.time()
|
| 121 |
+
try:
|
| 122 |
+
result = await func(*args, **kwargs)
|
| 123 |
+
success = True
|
| 124 |
+
return result
|
| 125 |
+
except Exception as e:
|
| 126 |
+
success = False
|
| 127 |
+
raise
|
| 128 |
+
finally:
|
| 129 |
+
end_time = time.time()
|
| 130 |
+
execution_time = end_time - start_time
|
| 131 |
+
performance_monitor.record_request(execution_time, success)
|
| 132 |
+
logger.debug(f"{func.__name__} execution time: {execution_time:.3f}s")
|
| 133 |
+
|
| 134 |
+
return wrapper
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class SimpleCache:
|
| 138 |
+
"""Simple in-memory cache for inference results."""
|
| 139 |
+
|
| 140 |
+
def __init__(self, max_size: int = 100, ttl: int = 3600):
|
| 141 |
+
"""
|
| 142 |
+
Initialize cache.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
max_size: Maximum number of items to cache
|
| 146 |
+
ttl: Time to live in seconds
|
| 147 |
+
"""
|
| 148 |
+
self.max_size = max_size
|
| 149 |
+
self.ttl = ttl
|
| 150 |
+
self.cache = {}
|
| 151 |
+
self.access_times = {}
|
| 152 |
+
|
| 153 |
+
def _is_expired(self, key: str) -> bool:
|
| 154 |
+
"""Check if a cache entry is expired."""
|
| 155 |
+
if key not in self.access_times:
|
| 156 |
+
return True
|
| 157 |
+
|
| 158 |
+
return time.time() - self.access_times[key] > self.ttl
|
| 159 |
+
|
| 160 |
+
def _evict_expired(self):
|
| 161 |
+
"""Remove expired entries."""
|
| 162 |
+
current_time = time.time()
|
| 163 |
+
expired_keys = [
|
| 164 |
+
key for key, access_time in self.access_times.items()
|
| 165 |
+
if current_time - access_time > self.ttl
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
for key in expired_keys:
|
| 169 |
+
self.cache.pop(key, None)
|
| 170 |
+
self.access_times.pop(key, None)
|
| 171 |
+
|
| 172 |
+
def _evict_lru(self):
|
| 173 |
+
"""Remove least recently used entry."""
|
| 174 |
+
if not self.access_times:
|
| 175 |
+
return
|
| 176 |
+
|
| 177 |
+
lru_key = min(self.access_times.keys(), key=lambda k: self.access_times[k])
|
| 178 |
+
self.cache.pop(lru_key, None)
|
| 179 |
+
self.access_times.pop(lru_key, None)
|
| 180 |
+
|
| 181 |
+
def get(self, key: str) -> Optional[Any]:
|
| 182 |
+
"""Get item from cache."""
|
| 183 |
+
if key not in self.cache or self._is_expired(key):
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
self.access_times[key] = time.time()
|
| 187 |
+
return self.cache[key]
|
| 188 |
+
|
| 189 |
+
def set(self, key: str, value: Any):
|
| 190 |
+
"""Set item in cache."""
|
| 191 |
+
# Clean up expired entries
|
| 192 |
+
self._evict_expired()
|
| 193 |
+
|
| 194 |
+
# Evict LRU if at max size
|
| 195 |
+
while len(self.cache) >= self.max_size:
|
| 196 |
+
self._evict_lru()
|
| 197 |
+
|
| 198 |
+
self.cache[key] = value
|
| 199 |
+
self.access_times[key] = time.time()
|
| 200 |
+
|
| 201 |
+
def clear(self):
|
| 202 |
+
"""Clear all cache entries."""
|
| 203 |
+
self.cache.clear()
|
| 204 |
+
self.access_times.clear()
|
| 205 |
+
|
| 206 |
+
def size(self) -> int:
|
| 207 |
+
"""Get current cache size."""
|
| 208 |
+
return len(self.cache)
|
| 209 |
+
|
| 210 |
+
def stats(self) -> Dict[str, Any]:
|
| 211 |
+
"""Get cache statistics."""
|
| 212 |
+
return {
|
| 213 |
+
"size": len(self.cache),
|
| 214 |
+
"max_size": self.max_size,
|
| 215 |
+
"ttl": self.ttl,
|
| 216 |
+
"hit_ratio": getattr(self, '_hits', 0) / max(getattr(self, '_requests', 1), 1)
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Global cache instance
|
| 221 |
+
inference_cache = SimpleCache(max_size=50, ttl=1800) # 30 minutes TTL
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_system_info() -> Dict[str, Any]:
|
| 225 |
+
"""Get system information."""
|
| 226 |
+
try:
|
| 227 |
+
return {
|
| 228 |
+
"cpu_count": psutil.cpu_count(),
|
| 229 |
+
"cpu_percent": psutil.cpu_percent(interval=1),
|
| 230 |
+
"memory_total_gb": psutil.virtual_memory().total / (1024**3),
|
| 231 |
+
"memory_available_gb": psutil.virtual_memory().available / (1024**3),
|
| 232 |
+
"memory_percent": psutil.virtual_memory().percent,
|
| 233 |
+
"disk_usage_percent": psutil.disk_usage('/').percent
|
| 234 |
+
}
|
| 235 |
+
except Exception as e:
|
| 236 |
+
logger.error(f"Failed to get system info: {str(e)}")
|
| 237 |
+
return {"error": str(e)}
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI and web server
|
| 2 |
+
fastapi==0.104.1
|
| 3 |
+
uvicorn[standard]==0.24.0
|
| 4 |
+
|
| 5 |
+
# Machine Learning and Computer Vision
|
| 6 |
+
torch>=1.13.0
|
| 7 |
+
torchvision>=0.14.0
|
| 8 |
+
yolov5==7.0.13
|
| 9 |
+
opencv-python-headless==4.8.1.78
|
| 10 |
+
Pillow==10.1.0
|
| 11 |
+
numpy==1.24.3
|
| 12 |
+
|
| 13 |
+
# Data handling and validation
|
| 14 |
+
pydantic==2.5.0
|
| 15 |
+
pydantic-settings==2.1.0
|
| 16 |
+
|
| 17 |
+
# HuggingFace integration
|
| 18 |
+
huggingface-hub==0.19.4
|
| 19 |
+
|
| 20 |
+
# Utilities
|
| 21 |
+
python-multipart==0.0.6
|
| 22 |
+
aiofiles==23.2.1
|
| 23 |
+
|
| 24 |
+
# Performance monitoring and optimization
|
| 25 |
+
psutil==5.9.6
|
| 26 |
+
|
| 27 |
+
# Testing (optional, for development)
|
| 28 |
+
pytest==7.4.3
|
| 29 |
+
httpx==0.25.2
|
run_tests.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test runner script for the Marine Species Identification API.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
def run_pytest():
|
| 12 |
+
"""Run pytest tests."""
|
| 13 |
+
print("🧪 Running pytest tests...")
|
| 14 |
+
try:
|
| 15 |
+
result = subprocess.run([
|
| 16 |
+
sys.executable, "-m", "pytest",
|
| 17 |
+
"tests/",
|
| 18 |
+
"-v",
|
| 19 |
+
"--tb=short"
|
| 20 |
+
], check=True)
|
| 21 |
+
print("✅ All pytest tests passed!")
|
| 22 |
+
return True
|
| 23 |
+
except subprocess.CalledProcessError as e:
|
| 24 |
+
print(f"❌ Some pytest tests failed (exit code: {e.returncode})")
|
| 25 |
+
return False
|
| 26 |
+
except FileNotFoundError:
|
| 27 |
+
print("⚠️ pytest not found, skipping pytest tests")
|
| 28 |
+
return True
|
| 29 |
+
|
| 30 |
+
def run_simple_test():
|
| 31 |
+
"""Run the simple API test."""
|
| 32 |
+
print("🧪 Running simple API test...")
|
| 33 |
+
try:
|
| 34 |
+
result = subprocess.run([
|
| 35 |
+
sys.executable, "test_api_simple.py"
|
| 36 |
+
], check=True)
|
| 37 |
+
print("✅ Simple API test completed!")
|
| 38 |
+
return True
|
| 39 |
+
except subprocess.CalledProcessError as e:
|
| 40 |
+
print(f"❌ Simple API test failed (exit code: {e.returncode})")
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
"""Main test runner."""
|
| 45 |
+
print("🐟 Marine Species Identification API - Test Runner")
|
| 46 |
+
print("=" * 55)
|
| 47 |
+
|
| 48 |
+
# Change to project directory
|
| 49 |
+
project_dir = Path(__file__).parent
|
| 50 |
+
os.chdir(project_dir)
|
| 51 |
+
|
| 52 |
+
success = True
|
| 53 |
+
|
| 54 |
+
# Run pytest tests
|
| 55 |
+
if not run_pytest():
|
| 56 |
+
success = False
|
| 57 |
+
|
| 58 |
+
print()
|
| 59 |
+
|
| 60 |
+
# Note about simple test
|
| 61 |
+
print("📝 Note: To run the simple API test, start the API first:")
|
| 62 |
+
print(" python start_api.py")
|
| 63 |
+
print(" # Then in another terminal:")
|
| 64 |
+
print(" python test_api_simple.py")
|
| 65 |
+
|
| 66 |
+
print("=" * 55)
|
| 67 |
+
|
| 68 |
+
if success:
|
| 69 |
+
print("🎉 All available tests completed successfully!")
|
| 70 |
+
return 0
|
| 71 |
+
else:
|
| 72 |
+
print("❌ Some tests failed")
|
| 73 |
+
return 1
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
sys.exit(main())
|
start_api.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Startup script for the Marine Species Identification API.
|
| 4 |
+
This script handles model downloading and API startup.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Add the app directory to Python path
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 14 |
+
|
| 15 |
+
from app.core.config import settings
|
| 16 |
+
from app.core.logging import setup_logging, get_logger
|
| 17 |
+
from app.utils.model_utils import (
|
| 18 |
+
download_model_from_hf,
|
| 19 |
+
verify_model_file,
|
| 20 |
+
setup_model_directory,
|
| 21 |
+
list_available_files
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Setup logging
|
| 25 |
+
setup_logging()
|
| 26 |
+
logger = get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
async def ensure_model_available():
|
| 30 |
+
"""Ensure the model is downloaded and available."""
|
| 31 |
+
logger.info("🔍 Checking model availability...")
|
| 32 |
+
|
| 33 |
+
# Setup model directory
|
| 34 |
+
model_dir = setup_model_directory()
|
| 35 |
+
logger.info(f"Model directory: {model_dir}")
|
| 36 |
+
|
| 37 |
+
# Check if model file exists
|
| 38 |
+
if verify_model_file(settings.MODEL_PATH):
|
| 39 |
+
logger.info("✅ Model file found and verified")
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
logger.info("📥 Model not found locally, attempting to download...")
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
# List available files in the repository
|
| 46 |
+
logger.info(f"Checking repository: {settings.HUGGINGFACE_REPO}")
|
| 47 |
+
available_files = list_available_files(settings.HUGGINGFACE_REPO)
|
| 48 |
+
|
| 49 |
+
if available_files:
|
| 50 |
+
logger.info(f"Available files in repository:")
|
| 51 |
+
for file in available_files[:10]: # Show first 10 files
|
| 52 |
+
logger.info(f" - {file}")
|
| 53 |
+
if len(available_files) > 10:
|
| 54 |
+
logger.info(f" ... and {len(available_files) - 10} more files")
|
| 55 |
+
|
| 56 |
+
# Download the model
|
| 57 |
+
model_filename = f"{settings.MODEL_NAME}.pt"
|
| 58 |
+
|
| 59 |
+
if model_filename in available_files:
|
| 60 |
+
download_model_from_hf(
|
| 61 |
+
repo_id=settings.HUGGINGFACE_REPO,
|
| 62 |
+
model_filename=model_filename,
|
| 63 |
+
local_dir=model_dir,
|
| 64 |
+
force_download=False
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Verify the downloaded model
|
| 68 |
+
if verify_model_file(settings.MODEL_PATH):
|
| 69 |
+
logger.info("✅ Model downloaded and verified successfully")
|
| 70 |
+
return True
|
| 71 |
+
else:
|
| 72 |
+
logger.error("❌ Downloaded model failed verification")
|
| 73 |
+
return False
|
| 74 |
+
else:
|
| 75 |
+
logger.error(f"❌ Model file '{model_filename}' not found in repository")
|
| 76 |
+
logger.info("Available .pt files:")
|
| 77 |
+
pt_files = [f for f in available_files if f.endswith('.pt')]
|
| 78 |
+
for pt_file in pt_files:
|
| 79 |
+
logger.info(f" - {pt_file}")
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.error(f"❌ Failed to download model: {str(e)}")
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def start_api():
|
| 88 |
+
"""Start the FastAPI application."""
|
| 89 |
+
import uvicorn
|
| 90 |
+
|
| 91 |
+
logger.info("🚀 Starting Marine Species Identification API...")
|
| 92 |
+
logger.info(f"Host: {settings.HOST}")
|
| 93 |
+
logger.info(f"Port: {settings.PORT}")
|
| 94 |
+
logger.info(f"Docs: http://{settings.HOST}:{settings.PORT}/docs")
|
| 95 |
+
|
| 96 |
+
uvicorn.run(
|
| 97 |
+
"app.main:app",
|
| 98 |
+
host=settings.HOST,
|
| 99 |
+
port=settings.PORT,
|
| 100 |
+
reload=False,
|
| 101 |
+
log_level="info",
|
| 102 |
+
access_log=True
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
async def main():
|
| 107 |
+
"""Main startup function."""
|
| 108 |
+
logger.info("🐟 Marine Species Identification API Startup")
|
| 109 |
+
logger.info("=" * 50)
|
| 110 |
+
|
| 111 |
+
# Check model availability
|
| 112 |
+
model_available = await ensure_model_available()
|
| 113 |
+
|
| 114 |
+
if not model_available:
|
| 115 |
+
logger.warning("⚠️ Model not available - API will start but inference may fail")
|
| 116 |
+
logger.info("The API will still start and you can check /health for status")
|
| 117 |
+
|
| 118 |
+
logger.info("=" * 50)
|
| 119 |
+
|
| 120 |
+
# Start the API
|
| 121 |
+
start_api()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
try:
|
| 126 |
+
asyncio.run(main())
|
| 127 |
+
except KeyboardInterrupt:
|
| 128 |
+
logger.info("🛑 API startup interrupted by user")
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"❌ Failed to start API: {str(e)}")
|
| 131 |
+
sys.exit(1)
|
test_api_simple.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple test script for the Marine Species Identification API.
|
| 4 |
+
This script can be used to quickly test the API functionality.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import base64
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
import io
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_test_image(width: int = 640, height: int = 480) -> str:
|
| 17 |
+
"""Create a test image and return as base64 string."""
|
| 18 |
+
# Create a simple test image with some patterns
|
| 19 |
+
image = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
| 20 |
+
|
| 21 |
+
# Add some simple patterns to make it more interesting
|
| 22 |
+
image[100:200, 100:200] = [255, 0, 0] # Red square
|
| 23 |
+
image[300:400, 300:400] = [0, 255, 0] # Green square
|
| 24 |
+
|
| 25 |
+
pil_image = Image.fromarray(image)
|
| 26 |
+
|
| 27 |
+
# Convert to base64
|
| 28 |
+
buffer = io.BytesIO()
|
| 29 |
+
pil_image.save(buffer, format="JPEG", quality=85)
|
| 30 |
+
image_bytes = buffer.getvalue()
|
| 31 |
+
|
| 32 |
+
return base64.b64encode(image_bytes).decode('utf-8')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_api(base_url: str = "http://localhost:7860"):
|
| 36 |
+
"""Test the API endpoints."""
|
| 37 |
+
|
| 38 |
+
print(f"🧪 Testing Marine Species Identification API at {base_url}")
|
| 39 |
+
print("=" * 60)
|
| 40 |
+
|
| 41 |
+
# Test 1: Root endpoint
|
| 42 |
+
print("1. Testing root endpoint...")
|
| 43 |
+
try:
|
| 44 |
+
response = requests.get(f"{base_url}/")
|
| 45 |
+
print(f" Status: {response.status_code}")
|
| 46 |
+
if response.status_code == 200:
|
| 47 |
+
print(f" Response: {response.json()}")
|
| 48 |
+
print()
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f" Error: {e}")
|
| 51 |
+
return
|
| 52 |
+
|
| 53 |
+
# Test 2: Health check
|
| 54 |
+
print("2. Testing health check...")
|
| 55 |
+
try:
|
| 56 |
+
response = requests.get(f"{base_url}/api/v1/health")
|
| 57 |
+
print(f" Status: {response.status_code}")
|
| 58 |
+
if response.status_code == 200:
|
| 59 |
+
health_data = response.json()
|
| 60 |
+
print(f" API Status: {health_data.get('status')}")
|
| 61 |
+
print(f" Model Loaded: {health_data.get('model_loaded')}")
|
| 62 |
+
print()
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f" Error: {e}")
|
| 65 |
+
print()
|
| 66 |
+
|
| 67 |
+
# Test 3: API info
|
| 68 |
+
print("3. Testing API info...")
|
| 69 |
+
try:
|
| 70 |
+
response = requests.get(f"{base_url}/api/v1/info")
|
| 71 |
+
print(f" Status: {response.status_code}")
|
| 72 |
+
if response.status_code == 200:
|
| 73 |
+
info_data = response.json()
|
| 74 |
+
print(f" API Name: {info_data.get('name')}")
|
| 75 |
+
print(f" Version: {info_data.get('version')}")
|
| 76 |
+
model_info = info_data.get('model_info', {})
|
| 77 |
+
print(f" Model Classes: {model_info.get('total_classes')}")
|
| 78 |
+
print()
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f" Error: {e}")
|
| 81 |
+
print()
|
| 82 |
+
|
| 83 |
+
# Test 4: Species list
|
| 84 |
+
print("4. Testing species list...")
|
| 85 |
+
try:
|
| 86 |
+
response = requests.get(f"{base_url}/api/v1/species")
|
| 87 |
+
print(f" Status: {response.status_code}")
|
| 88 |
+
if response.status_code == 200:
|
| 89 |
+
species_data = response.json()
|
| 90 |
+
total_species = species_data.get('total_count', 0)
|
| 91 |
+
print(f" Total Species: {total_species}")
|
| 92 |
+
if total_species > 0:
|
| 93 |
+
print(f" First 3 species:")
|
| 94 |
+
for species in species_data.get('species', [])[:3]:
|
| 95 |
+
print(f" - {species.get('class_name')} (ID: {species.get('class_id')})")
|
| 96 |
+
print()
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f" Error: {e}")
|
| 99 |
+
print()
|
| 100 |
+
|
| 101 |
+
# Test 5: Detection with test image
|
| 102 |
+
print("5. Testing marine species detection...")
|
| 103 |
+
try:
|
| 104 |
+
# Create a test image
|
| 105 |
+
print(" Creating test image...")
|
| 106 |
+
test_image_b64 = create_test_image()
|
| 107 |
+
|
| 108 |
+
# Prepare request
|
| 109 |
+
detection_request = {
|
| 110 |
+
"image": test_image_b64,
|
| 111 |
+
"confidence_threshold": 0.25,
|
| 112 |
+
"iou_threshold": 0.45,
|
| 113 |
+
"image_size": 640,
|
| 114 |
+
"return_annotated_image": True
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
print(" Sending detection request...")
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
|
| 120 |
+
response = requests.post(
|
| 121 |
+
f"{base_url}/api/v1/detect",
|
| 122 |
+
json=detection_request,
|
| 123 |
+
timeout=30
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
end_time = time.time()
|
| 127 |
+
request_time = end_time - start_time
|
| 128 |
+
|
| 129 |
+
print(f" Status: {response.status_code}")
|
| 130 |
+
print(f" Request Time: {request_time:.2f}s")
|
| 131 |
+
|
| 132 |
+
if response.status_code == 200:
|
| 133 |
+
detection_data = response.json()
|
| 134 |
+
detections = detection_data.get('detections', [])
|
| 135 |
+
processing_time = detection_data.get('processing_time', 0)
|
| 136 |
+
|
| 137 |
+
print(f" Processing Time: {processing_time:.3f}s")
|
| 138 |
+
print(f" Detections Found: {len(detections)}")
|
| 139 |
+
|
| 140 |
+
if detections:
|
| 141 |
+
print(" Top detections:")
|
| 142 |
+
for i, detection in enumerate(detections[:3]):
|
| 143 |
+
print(f" {i+1}. {detection.get('class_name')} "
|
| 144 |
+
f"(confidence: {detection.get('confidence'):.3f})")
|
| 145 |
+
|
| 146 |
+
# Check if annotated image was returned
|
| 147 |
+
if detection_data.get('annotated_image'):
|
| 148 |
+
print(" ✅ Annotated image returned")
|
| 149 |
+
else:
|
| 150 |
+
print(" ❌ No annotated image returned")
|
| 151 |
+
|
| 152 |
+
elif response.status_code == 503:
|
| 153 |
+
print(" ⚠️ Service unavailable (model may not be loaded)")
|
| 154 |
+
else:
|
| 155 |
+
print(f" ❌ Error: {response.text}")
|
| 156 |
+
|
| 157 |
+
print()
|
| 158 |
+
|
| 159 |
+
except Exception as e:
|
| 160 |
+
print(f" Error: {e}")
|
| 161 |
+
print()
|
| 162 |
+
|
| 163 |
+
print("🎉 API testing completed!")
|
| 164 |
+
print("=" * 60)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
import sys
|
| 169 |
+
|
| 170 |
+
# Allow custom base URL
|
| 171 |
+
base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
|
| 172 |
+
|
| 173 |
+
test_api(base_url)
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Test package
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic API tests for the Marine Species Identification API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
from fastapi.testclient import TestClient
|
| 11 |
+
|
| 12 |
+
from app.main import app
|
| 13 |
+
|
| 14 |
+
client = TestClient(app)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_test_image(width: int = 640, height: int = 480) -> str:
|
| 18 |
+
"""Create a test image and return as base64 string."""
|
| 19 |
+
# Create a simple test image
|
| 20 |
+
image = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
| 21 |
+
pil_image = Image.fromarray(image)
|
| 22 |
+
|
| 23 |
+
# Convert to base64
|
| 24 |
+
buffer = io.BytesIO()
|
| 25 |
+
pil_image.save(buffer, format="JPEG")
|
| 26 |
+
image_bytes = buffer.getvalue()
|
| 27 |
+
|
| 28 |
+
return base64.b64encode(image_bytes).decode('utf-8')
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestHealthEndpoints:
|
| 32 |
+
"""Test health and status endpoints."""
|
| 33 |
+
|
| 34 |
+
def test_root_endpoint(self):
|
| 35 |
+
"""Test root endpoint."""
|
| 36 |
+
response = client.get("/")
|
| 37 |
+
assert response.status_code == 200
|
| 38 |
+
data = response.json()
|
| 39 |
+
assert "message" in data
|
| 40 |
+
assert "version" in data
|
| 41 |
+
|
| 42 |
+
def test_root_health(self):
|
| 43 |
+
"""Test root health endpoint."""
|
| 44 |
+
response = client.get("/health")
|
| 45 |
+
assert response.status_code == 200
|
| 46 |
+
data = response.json()
|
| 47 |
+
assert data["status"] == "ok"
|
| 48 |
+
|
| 49 |
+
def test_health_check(self):
|
| 50 |
+
"""Test detailed health check."""
|
| 51 |
+
response = client.get("/api/v1/health")
|
| 52 |
+
assert response.status_code == 200
|
| 53 |
+
data = response.json()
|
| 54 |
+
assert "status" in data
|
| 55 |
+
assert "model_loaded" in data
|
| 56 |
+
assert "timestamp" in data
|
| 57 |
+
|
| 58 |
+
def test_api_info(self):
|
| 59 |
+
"""Test API info endpoint."""
|
| 60 |
+
response = client.get("/api/v1/info")
|
| 61 |
+
assert response.status_code == 200
|
| 62 |
+
data = response.json()
|
| 63 |
+
assert "name" in data
|
| 64 |
+
assert "version" in data
|
| 65 |
+
assert "endpoints" in data
|
| 66 |
+
|
| 67 |
+
def test_liveness_check(self):
|
| 68 |
+
"""Test liveness probe."""
|
| 69 |
+
response = client.get("/api/v1/live")
|
| 70 |
+
assert response.status_code == 200
|
| 71 |
+
data = response.json()
|
| 72 |
+
assert data["status"] == "alive"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TestSpeciesEndpoints:
|
| 76 |
+
"""Test species-related endpoints."""
|
| 77 |
+
|
| 78 |
+
def test_list_species(self):
|
| 79 |
+
"""Test species list endpoint."""
|
| 80 |
+
response = client.get("/api/v1/species")
|
| 81 |
+
assert response.status_code in [200, 503] # May fail if model not loaded
|
| 82 |
+
|
| 83 |
+
if response.status_code == 200:
|
| 84 |
+
data = response.json()
|
| 85 |
+
assert "species" in data
|
| 86 |
+
assert "total_count" in data
|
| 87 |
+
assert isinstance(data["species"], list)
|
| 88 |
+
|
| 89 |
+
def test_get_species_info(self):
|
| 90 |
+
"""Test individual species info endpoint."""
|
| 91 |
+
# This may fail if model is not loaded, which is expected in test environment
|
| 92 |
+
response = client.get("/api/v1/species/0")
|
| 93 |
+
assert response.status_code in [200, 404, 503]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TestInferenceEndpoints:
|
| 97 |
+
"""Test inference endpoints."""
|
| 98 |
+
|
| 99 |
+
def test_detect_invalid_image(self):
|
| 100 |
+
"""Test detection with invalid image data."""
|
| 101 |
+
response = client.post(
|
| 102 |
+
"/api/v1/detect",
|
| 103 |
+
json={
|
| 104 |
+
"image": "invalid_base64_data",
|
| 105 |
+
"confidence_threshold": 0.25
|
| 106 |
+
}
|
| 107 |
+
)
|
| 108 |
+
assert response.status_code in [400, 503] # Bad request or service unavailable
|
| 109 |
+
|
| 110 |
+
def test_detect_valid_request_format(self):
|
| 111 |
+
"""Test detection with valid request format."""
|
| 112 |
+
test_image = create_test_image()
|
| 113 |
+
|
| 114 |
+
response = client.post(
|
| 115 |
+
"/api/v1/detect",
|
| 116 |
+
json={
|
| 117 |
+
"image": test_image,
|
| 118 |
+
"confidence_threshold": 0.25,
|
| 119 |
+
"iou_threshold": 0.45,
|
| 120 |
+
"image_size": 640,
|
| 121 |
+
"return_annotated_image": True
|
| 122 |
+
}
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# May return 503 if model is not loaded, which is expected in test environment
|
| 126 |
+
assert response.status_code in [200, 503]
|
| 127 |
+
|
| 128 |
+
if response.status_code == 200:
|
| 129 |
+
data = response.json()
|
| 130 |
+
assert "detections" in data
|
| 131 |
+
assert "processing_time" in data
|
| 132 |
+
assert "model_info" in data
|
| 133 |
+
assert "image_dimensions" in data
|
| 134 |
+
|
| 135 |
+
def test_detect_parameter_validation(self):
|
| 136 |
+
"""Test parameter validation."""
|
| 137 |
+
test_image = create_test_image()
|
| 138 |
+
|
| 139 |
+
# Test invalid confidence threshold
|
| 140 |
+
response = client.post(
|
| 141 |
+
"/api/v1/detect",
|
| 142 |
+
json={
|
| 143 |
+
"image": test_image,
|
| 144 |
+
"confidence_threshold": 1.5 # Invalid: > 1.0
|
| 145 |
+
}
|
| 146 |
+
)
|
| 147 |
+
assert response.status_code == 422 # Validation error
|
| 148 |
+
|
| 149 |
+
# Test invalid image size
|
| 150 |
+
response = client.post(
|
| 151 |
+
"/api/v1/detect",
|
| 152 |
+
json={
|
| 153 |
+
"image": test_image,
|
| 154 |
+
"image_size": 100 # Invalid: < 320
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
assert response.status_code == 422 # Validation error
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TestErrorHandling:
|
| 161 |
+
"""Test error handling."""
|
| 162 |
+
|
| 163 |
+
def test_404_endpoint(self):
|
| 164 |
+
"""Test non-existent endpoint."""
|
| 165 |
+
response = client.get("/api/v1/nonexistent")
|
| 166 |
+
assert response.status_code == 404
|
| 167 |
+
|
| 168 |
+
def test_method_not_allowed(self):
|
| 169 |
+
"""Test wrong HTTP method."""
|
| 170 |
+
response = client.get("/api/v1/detect") # Should be POST
|
| 171 |
+
assert response.status_code == 405
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
pytest.main([__file__])
|