kamau1 committed on
Commit bcc2f7b · verified · 1 Parent(s): 90d1de1

Initial commit
DEPLOYMENT.md ADDED
@@ -0,0 +1,257 @@
# 🐟 Marine Species Identification API - Deployment Guide

## Quick Start

### 1. Local Development

```bash
# Install dependencies
pip install -r requirements.txt

# Start the API (with automatic model download)
python start_api.py

# Or start directly with uvicorn
uvicorn app.main:app --host 0.0.0.0 --port 7860
```

### 2. Docker Deployment

```bash
# Build the Docker image
docker build -t marine-species-api .

# Run the container
docker run -p 7860:7860 marine-species-api
```

### 3. HuggingFace Spaces Deployment

1. Create a new Space on HuggingFace Hub
2. Set the SDK to "Docker"
3. Upload all files to your Space repository
4. The Space will build and deploy automatically

## Configuration

### Environment Variables

Copy `.env.example` to `.env` and modify as needed:

```bash
# Model Configuration
HUGGINGFACE_REPO=your-username/your-model-repo
MODEL_NAME=marina-benthic-33k

# Performance
ENABLE_MODEL_WARMUP=true
MAX_FILE_SIZE=10485760

# Server
HOST=0.0.0.0
PORT=7860
```

### Model Setup

The API automatically downloads the model from HuggingFace Hub on first startup. Ensure your model repository contains:

- `marina-benthic-33k.pt` - The YOLOv5 model file
- `marina-benthic-33k.names` - Class names file (optional)

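If the automatic download at startup is not an option (air-gapped builds, CI), the weights can be fetched ahead of time. A minimal pre-download sketch — the repo and file names mirror the placeholders above and must be replaced with your own:

```python
# Hypothetical pre-download helper; REPO is a placeholder, not a real repo.
from pathlib import Path

REPO = "your-username/your-model-repo"
FILES = ["marina-benthic-33k.pt", "marina-benthic-33k.names"]


def target_paths(model_dir: str = "models") -> list:
    """Local paths the API expects the model files under."""
    return [Path(model_dir) / name for name in FILES]


def predownload(model_dir: str = "models") -> None:
    # Imported lazily so the helper can be inspected without the package.
    from huggingface_hub import hf_hub_download
    for name in FILES:
        hf_hub_download(REPO, name, local_dir=model_dir)


if __name__ == "__main__":
    predownload()
```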
## Testing

### Run All Tests
```bash
python run_tests.py
```

### Test the API Manually
```bash
# Start the API
python start_api.py

# In another terminal, run the test client
python test_api_simple.py
```

### Test Specific Endpoints
```bash
# Health check
curl http://localhost:7860/api/v1/health

# API info
curl http://localhost:7860/api/v1/info

# List species
curl http://localhost:7860/api/v1/species
```
## API Usage

### Detect Marine Species

```python
import requests
import base64

# Read and encode the image
with open("marine_image.jpg", "rb") as f:
    image_data = base64.b64encode(f.read()).decode()

# Make the detection request
response = requests.post("http://localhost:7860/api/v1/detect", json={
    "image": image_data,
    "confidence_threshold": 0.25,
    "return_annotated_image": True
})

result = response.json()
print(f"Found {len(result['detections'])} marine species")
```
### JavaScript Example

```javascript
// Convert the image to base64 (convertImageToBase64 is your own helper,
// e.g. built on FileReader.readAsDataURL)
const imageBase64 = await convertImageToBase64(imageFile);

// Make the detection request
const response = await fetch('/api/v1/detect', {
  method: 'POST',
  headers: {
    'Content-Type': 'application/json',
  },
  body: JSON.stringify({
    image: imageBase64,
    confidence_threshold: 0.25,
    return_annotated_image: true
  })
});

const result = await response.json();
console.log('Detections:', result.detections);
```
## Performance Optimization

### Model Caching
- The model is loaded once on startup and cached in memory
- Supports model warmup for a faster first inference

### Request Caching
- Optional caching of inference results
- Configurable TTL and cache size

### Monitoring
- Built-in performance metrics
- System resource monitoring
- Request timing and success rates

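The request cache itself is not part of this commit; as a rough illustration of the "configurable TTL and cache size" behaviour described above, a minimal sketch:

```python
# Minimal sketch of a result cache with a TTL and a size bound;
# the API's actual caching layer is not shown in this commit.
import time
from collections import OrderedDict


class TTLCache:
    def __init__(self, max_size: int = 128, ttl: float = 300.0):
        self.max_size, self.ttl = max_size, ttl
        self._store = OrderedDict()  # key -> (expiry timestamp, value)

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires, value = entry
        if time.monotonic() > expires:
            del self._store[key]  # expired
            return None
        self._store.move_to_end(key)  # keep recently used entries
        return value

    def put(self, key, value) -> None:
        self._store[key] = (time.monotonic() + self.ttl, value)
        self._store.move_to_end(key)
        while len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the oldest entry


cache = TTLCache(max_size=2, ttl=60.0)
cache.put("a", 1)
cache.put("b", 2)
cache.put("c", 3)  # size bound reached, "a" is evicted
```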
## Troubleshooting

### Common Issues

1. **Model Download Fails**
   ```bash
   # Check repository access
   python -c "from huggingface_hub import list_repo_files; print(list_repo_files('your-repo'))"

   # Manual download
   python app/utils/model_utils.py --download
   ```

2. **Out of Memory**
   - Reduce the `image_size` parameter
   - Use CPU instead of GPU for inference
   - Increase Docker memory limits

3. **Slow Inference**
   - Enable model warmup
   - Use a GPU if available
   - Optimize image preprocessing

### Health Checks

```bash
# Basic health
curl http://localhost:7860/health

# Detailed health with model status
curl http://localhost:7860/api/v1/health

# Readiness probe (for Kubernetes)
curl http://localhost:7860/api/v1/ready

# Liveness probe
curl http://localhost:7860/api/v1/live
```
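Because the model may still be loading when the container starts, deploy scripts can poll the readiness probe before routing traffic. A hedged sketch using only the standard library (the endpoint URL is the one documented above):

```python
# Poll the readiness endpoint with capped exponential backoff until the
# model finishes loading; useful in deploy or smoke-test scripts.
import time
import urllib.error
import urllib.request


def backoff_delays(base: float = 1.0, factor: float = 2.0,
                   limit: float = 30.0, tries: int = 6) -> list:
    """Delay schedule in seconds: 1, 2, 4, ... capped at `limit`."""
    return [min(base * factor ** i, limit) for i in range(tries)]


def wait_until_ready(url: str = "http://localhost:7860/api/v1/ready") -> bool:
    for delay in backoff_delays():
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # API not up yet (connection refused or 503)
        time.sleep(delay)
    return False


if __name__ == "__main__":
    print("ready" if wait_until_ready() else "gave up")
```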
## Production Deployment

### Docker Compose

```yaml
version: '3.8'
services:
  marine-api:
    build: .
    ports:
      - "7860:7860"
    environment:
      - ENABLE_MODEL_WARMUP=true
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
      interval: 30s
      timeout: 10s
      retries: 3
```

### Kubernetes

```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: marine-species-api
spec:
  replicas: 2
  selector:
    matchLabels:
      app: marine-species-api
  template:
    metadata:
      labels:
        app: marine-species-api
    spec:
      containers:
        - name: api
          image: marine-species-api:latest
          ports:
            - containerPort: 7860
          livenessProbe:
            httpGet:
              path: /api/v1/live
              port: 7860
            initialDelaySeconds: 60
          readinessProbe:
            httpGet:
              path: /api/v1/ready
              port: 7860
            initialDelaySeconds: 30
```
## Security Considerations

- Configure CORS appropriately for production
- Add rate limiting for public APIs
- Validate and sanitize all inputs
- Use HTTPS in production
- Monitor for unusual usage patterns

## Monitoring and Logging

- Structured logging with configurable levels
- Performance metrics collection
- Health check endpoints for monitoring systems
- Error tracking and alerting
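As one concrete instance of "validate and sanitize all inputs", a sketch of upload validation mirroring the `MAX_FILE_SIZE` and `ALLOWED_EXTENSIONS` limits from `app/core/config.py` (the function name is illustrative, not part of the API):

```python
# Illustrative upload validation matching the limits in app/core/config.py.
from pathlib import Path
from typing import Optional

MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB, same as the config default
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}


def validate_upload(filename: str, data: bytes) -> Optional[str]:
    """Return an error message, or None if the upload is acceptable."""
    if len(data) > MAX_FILE_SIZE:
        return "file too large"
    if Path(filename).suffix.lower() not in ALLOWED_EXTENSIONS:
        return "unsupported file type"
    return None


err = validate_upload("reef.JPG", b"\xff\xd8\xff")  # tiny JPEG-ish payload
```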
Dockerfile ADDED
@@ -0,0 +1,101 @@
# Dockerfile for Marine Species Identification API on HuggingFace Spaces
# Multi-stage build to handle model downloading with proper permissions

# Stage 1: Download models as root
FROM python:3.10-slim AS model-builder

# Install system dependencies for model downloading
RUN apt-get update && apt-get install -y \
    wget \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install huggingface_hub for downloading models
RUN pip install huggingface_hub

# Create models directory
RUN mkdir -p /models

# Download models from HuggingFace Hub.
# Note: Replace 'seamo-ai/marina-species-v1' with your actual HF repo.
# On failure, write a placeholder file so the image build itself does not
# fail; the runtime health checks will then report the model as unavailable.
RUN python -c "from huggingface_hub import hf_hub_download; hf_hub_download('seamo-ai/marina-species-v1', 'marina-benthic-33k.pt', local_dir='/models', local_dir_use_symlinks=False)" \
    && echo "Model downloaded successfully" \
    || (echo "Model download failed; writing placeholder" \
        && echo placeholder > /models/marina-benthic-33k.pt)

# Try to download the class names file (optional)
RUN python -c "from huggingface_hub import hf_hub_download; hf_hub_download('seamo-ai/marina-species-v1', 'marina-benthic-33k.names', local_dir='/models', local_dir_use_symlinks=False)" \
    && echo "Class names downloaded successfully" \
    || echo "Class names download failed (file is optional)"

# Stage 2: Build the application
FROM python:3.10-slim

# Install system dependencies (curl is needed by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y \
    curl \
    ffmpeg \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libglib2.0-0 \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user

# Switch to the "user" user
USER user

# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Environment variables for HuggingFace and ML libraries.
# HF_HUB_OFFLINE=1 because the weights were baked in at build time.
ENV HF_HUB_OFFLINE=1
ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1
ENV PYTHONPATH=$HOME/app
ENV TORCH_HOME=$HOME/.cache/torch
ENV HF_HOME=$HOME/.cache/huggingface

# Copy the requirements file and install dependencies
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir --user -r requirements.txt

# Copy the downloaded models from the builder stage
COPY --chown=user --from=model-builder /models $HOME/app/models

# Copy the application code
COPY --chown=user ./app app

# Create necessary directories
RUN mkdir -p $HOME/.cache/huggingface $HOME/.cache/torch

# Expose port 7860 (HuggingFace Spaces standard)
EXPOSE 7860

# Health check (requires curl, installed above)
HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Run uvicorn on port 7860, the standard for HF Spaces.
# 0.0.0.0 makes it reachable from outside the container.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
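Since the builder stage writes a tiny placeholder file when the download fails, the application can detect it before attempting to load the weights. A hedged sketch of such a startup guard (not part of this commit; the threshold is an assumption):

```python
# Hypothetical startup guard: treat a missing or suspiciously small
# weights file as the build-time placeholder rather than a real model.
from pathlib import Path


def is_placeholder(path: str, min_bytes: int = 1024) -> bool:
    """True if the weights file is missing or far too small to be real."""
    p = Path(path)
    return (not p.exists()) or p.stat().st_size < min_bytes
```

Any real YOLOv5 checkpoint is many megabytes, while the placeholder written above is only a few bytes, so a small size threshold is enough to tell them apart.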
app/__init__.py ADDED
@@ -0,0 +1,9 @@
"""
FastAPI Marine Species Identification API

A scalable API for marine species identification using a YOLOv5 model.
"""

__version__ = "1.0.0"
__author__ = "Seamo AI"
__description__ = "Marine Species Identification API"
app/api/__init__.py ADDED
@@ -0,0 +1 @@
# API package
app/api/dependencies.py ADDED
@@ -0,0 +1,47 @@
"""
FastAPI dependencies for the marine species identification API.
"""

from fastapi import HTTPException, status
from app.services.model_service import model_service
from app.core.logging import get_logger

logger = get_logger(__name__)


async def get_model_service():
    """
    Dependency to get the model service.
    Ensures the model is loaded and available.
    """
    try:
        await model_service.ensure_model_available()
        return model_service
    except Exception as e:
        logger.error(f"Failed to initialize model service: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Model service unavailable: {str(e)}"
        )


async def validate_model_health():
    """
    Dependency to validate model health before processing requests.
    """
    try:
        health_status = await model_service.health_check()
        if not health_status.get("model_loaded", False):
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model is not loaded or unhealthy"
            )
        return True
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model health check failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model health check failed"
        )
app/api/v1/__init__.py ADDED
@@ -0,0 +1 @@
# API v1 package
app/api/v1/api.py ADDED
@@ -0,0 +1,20 @@
"""
API v1 router configuration.
"""

from fastapi import APIRouter

from app.api.v1.endpoints import inference, health

api_router = APIRouter()

# Include endpoint routers
api_router.include_router(
    inference.router,
    tags=["Marine Species Detection"]
)

api_router.include_router(
    health.router,
    tags=["Health & Status"]
)
app/api/v1/endpoints/__init__.py ADDED
@@ -0,0 +1 @@
# API v1 endpoints
app/api/v1/endpoints/health.py ADDED
@@ -0,0 +1,150 @@
"""
Health check and system status endpoints.
"""

from datetime import datetime
from fastapi import APIRouter, HTTPException, status

from app.models.inference import HealthResponse, APIInfo, ModelInfo, ErrorResponse
from app.services.model_service import model_service
from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)

router = APIRouter()


@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Health Check",
    description="Check the health status of the API and model"
)
async def health_check() -> HealthResponse:
    """
    Perform a comprehensive health check of the API and model.

    Returns the current status of the API, model loading status, and basic model information.
    This endpoint can be used for monitoring and load balancer health checks.
    """
    try:
        logger.debug("Performing health check")

        # Check model health
        health_status = await model_service.health_check()

        model_info = None
        if health_status.get("model_loaded", False):
            model_info_dict = health_status.get("model_info", {})
            if model_info_dict:
                model_info = ModelInfo(**model_info_dict)

        return HealthResponse(
            status="healthy" if health_status.get("model_loaded", False) else "degraded",
            model_loaded=health_status.get("model_loaded", False),
            model_info=model_info,
            timestamp=datetime.utcnow().isoformat()
        )

    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return HealthResponse(
            status="unhealthy",
            model_loaded=False,
            model_info=None,
            timestamp=datetime.utcnow().isoformat()
        )


@router.get(
    "/info",
    response_model=APIInfo,
    summary="API Information",
    description="Get comprehensive information about the API and model"
)
async def get_api_info() -> APIInfo:
    """
    Get comprehensive information about the API.

    Returns detailed information about the API version, capabilities, model information,
    and available endpoints.
    """
    try:
        logger.debug("Fetching API information")

        # Get model information
        model_info_dict = model_service.get_model_info()
        model_info = ModelInfo(**model_info_dict)

        # Define available endpoints
        endpoints = [
            f"{settings.API_V1_STR}/detect",
            f"{settings.API_V1_STR}/species",
            f"{settings.API_V1_STR}/species/{{class_id}}",
            f"{settings.API_V1_STR}/health",
            f"{settings.API_V1_STR}/info"
        ]

        return APIInfo(
            name=settings.PROJECT_NAME,
            version=settings.VERSION,
            description=settings.DESCRIPTION,
            model_info=model_info,
            endpoints=endpoints
        )

    except Exception as e:
        logger.error(f"Failed to get API info: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get API information: {str(e)}"
        )


@router.get(
    "/ready",
    summary="Readiness Check",
    description="Check if the API is ready to serve requests"
)
async def readiness_check() -> dict:
    """
    Check if the API is ready to serve requests.

    This endpoint is specifically designed for Kubernetes readiness probes.
    It returns 200 OK only when the model is loaded and ready to process requests.
    """
    try:
        health_status = await model_service.health_check()

        if health_status.get("model_loaded", False):
            return {"status": "ready"}
        else:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model not ready"
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Readiness check failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service not ready"
        )


@router.get(
    "/live",
    summary="Liveness Check",
    description="Check if the API is alive"
)
async def liveness_check() -> dict:
    """
    Check if the API is alive.

    This endpoint is designed for Kubernetes liveness probes.
    It performs a minimal check to ensure the API process is running.
    """
    return {"status": "alive", "timestamp": datetime.utcnow().isoformat()}
app/api/v1/endpoints/inference.py ADDED
@@ -0,0 +1,163 @@
"""
Marine species detection inference endpoints.
"""

from fastapi import APIRouter, HTTPException, status, Depends

from app.models.inference import (
    InferenceRequest,
    InferenceResponse,
    SpeciesListResponse,
    SpeciesInfo,
    ErrorResponse
)
from app.services.inference_service import inference_service
from app.api.dependencies import validate_model_health
from app.core.logging import get_logger

logger = get_logger(__name__)

router = APIRouter()


@router.post(
    "/detect",
    response_model=InferenceResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Bad Request"},
        503: {"model": ErrorResponse, "description": "Service Unavailable"},
    },
    summary="Detect Marine Species",
    description="Detect and identify marine species in an uploaded image using the YOLOv5 model"
)
async def detect_marine_species(
    request: InferenceRequest,
    _: bool = Depends(validate_model_health)
) -> InferenceResponse:
    """
    Detect marine species in an image.

    - **image**: Base64 encoded image data
    - **confidence_threshold**: Minimum confidence for detections (0.0-1.0)
    - **iou_threshold**: IoU threshold for non-maximum suppression (0.0-1.0)
    - **image_size**: Input image size for inference (320-1280)
    - **return_annotated_image**: Whether to return an annotated image with bounding boxes
    - **classes**: Optional list of class IDs to filter detections

    Returns detection results with bounding boxes, confidence scores, and species names.
    """
    try:
        logger.info("Processing marine species detection request")

        result = await inference_service.detect_species(
            image_data=request.image,
            confidence_threshold=request.confidence_threshold,
            iou_threshold=request.iou_threshold,
            image_size=request.image_size,
            return_annotated_image=request.return_annotated_image,
            classes=request.classes
        )

        logger.info(f"Detection completed: {len(result.detections)} species found")
        return result

    except ValueError as e:
        logger.error(f"Invalid input data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Detection failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Detection failed: {str(e)}"
        )


@router.get(
    "/species",
    response_model=SpeciesListResponse,
    summary="List Supported Species",
    description="Get a list of all marine species that can be detected by the model"
)
async def list_supported_species(
    _: bool = Depends(validate_model_health)
) -> SpeciesListResponse:
    """
    Get a list of all supported marine species.

    Returns a comprehensive list of all marine species that the model can detect,
    including their class IDs and scientific/common names.
    """
    try:
        logger.info("Fetching supported species list")

        species_data = await inference_service.get_supported_species()

        species_list = [
            SpeciesInfo(class_id=item["class_id"], class_name=item["class_name"])
            for item in species_data
        ]

        return SpeciesListResponse(
            species=species_list,
            total_count=len(species_list)
        )

    except Exception as e:
        logger.error(f"Failed to fetch species list: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch species list: {str(e)}"
        )


@router.get(
    "/species/{class_id}",
    response_model=SpeciesInfo,
    responses={
        404: {"model": ErrorResponse, "description": "Species Not Found"},
    },
    summary="Get Species Information",
    description="Get information about a specific marine species by class ID"
)
async def get_species_info(
    class_id: int,
    _: bool = Depends(validate_model_health)
) -> SpeciesInfo:
    """
    Get information about a specific marine species.

    - **class_id**: The class ID of the species to look up

    Returns detailed information about the specified marine species.
    """
    try:
        logger.info(f"Fetching species info for class_id: {class_id}")

        species_data = await inference_service.get_supported_species()

        # Find the species with the given class_id
        for species in species_data:
            if species["class_id"] == class_id:
                return SpeciesInfo(
                    class_id=species["class_id"],
                    class_name=species["class_name"]
                )

        # Species not found
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Species with class_id {class_id} not found"
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to fetch species info: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch species info: {str(e)}"
        )
app/core/__init__.py ADDED
@@ -0,0 +1 @@
# Core configuration and utilities
app/core/config.py ADDED
@@ -0,0 +1,54 @@
"""
Configuration management for the FastAPI application.
"""

import os
from typing import Optional
from pydantic import BaseSettings  # pydantic v1; on v2, use pydantic_settings.BaseSettings


class Settings(BaseSettings):
    """Application settings."""

    # API Configuration
    API_V1_STR: str = "/api/v1"
    PROJECT_NAME: str = "Marine Species Identification API"
    VERSION: str = "1.0.0"
    DESCRIPTION: str = "FastAPI-based marine species identification using YOLOv5"

    # Debug mode (controls whether error details appear in responses)
    DEBUG: bool = False

    # Model Configuration
    MODEL_NAME: str = "marina-benthic-33k"
    MODEL_PATH: str = "models/marina-benthic-33k.pt"
    HUGGINGFACE_REPO: str = "seamo-ai/marina-species-v1"
    DEVICE: Optional[str] = None  # Auto-detect if None

    # Inference Configuration
    DEFAULT_CONFIDENCE_THRESHOLD: float = 0.25
    DEFAULT_IOU_THRESHOLD: float = 0.45
    DEFAULT_IMAGE_SIZE: int = 720
    MAX_IMAGE_SIZE: int = 1280
    MIN_IMAGE_SIZE: int = 320

    # File Upload Configuration
    MAX_FILE_SIZE: int = 10 * 1024 * 1024  # 10MB
    ALLOWED_EXTENSIONS: set = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

    # Performance Configuration
    MODEL_CACHE_SIZE: int = 1
    ENABLE_MODEL_WARMUP: bool = True

    # HuggingFace Configuration
    HF_HUB_OFFLINE: bool = os.getenv("HF_HUB_OFFLINE", "0") == "1"
    TRANSFORMERS_NO_ADVISORY_WARNINGS: bool = True

    # Server Configuration
    HOST: str = "0.0.0.0"
    PORT: int = 7860  # HuggingFace Spaces standard

    class Config:
        env_file = ".env"
        case_sensitive = True


# Global settings instance
settings = Settings()
app/core/logging.py ADDED
@@ -0,0 +1,41 @@
"""
Logging configuration for the FastAPI application.
"""

import logging
import sys


def setup_logging(level: str = "INFO") -> None:
    """
    Set up logging configuration.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    """
    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )

    # Set specific loggers
    logging.getLogger("uvicorn").setLevel(logging.INFO)
    logging.getLogger("fastapi").setLevel(logging.INFO)
    logging.getLogger("yolov5").setLevel(logging.WARNING)  # Reduce YOLOv5 verbosity


def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance.

    Args:
        name: Logger name

    Returns:
        Logger instance
    """
    return logging.getLogger(name)
app/main.py ADDED
@@ -0,0 +1,119 @@
"""
FastAPI Marine Species Identification API

Main application entry point for the marine species identification API.
"""

from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn

from app.core.config import settings
from app.core.logging import setup_logging, get_logger
from app.api.v1.api import api_router
from app.services.model_service import model_service

# Setup logging
setup_logging()
logger = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.
    Handles startup and shutdown events.
    """
    # Startup
    logger.info("Starting Marine Species Identification API...")

    try:
        # Ensure model is available and loaded
        await model_service.ensure_model_available()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model during startup: {str(e)}")
        # Don't fail startup - let health checks handle this

    logger.info("API startup completed")

    yield

    # Shutdown
    logger.info("Shutting down Marine Species Identification API...")


# Create FastAPI application
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description=settings.DESCRIPTION,
    openapi_url=f"{settings.API_V1_STR}/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler for unhandled errors."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "message": "An unexpected error occurred",
            # Only expose details in debug mode; default False if DEBUG is unset
            "details": str(exc) if getattr(settings, "DEBUG", False) else None
        }
    )


# Include API router
app.include_router(api_router, prefix=settings.API_V1_STR)


# Root endpoint
@app.get("/", tags=["Root"])
async def root():
    """
    Root endpoint providing basic API information.
    """
    return {
        "message": "Marine Species Identification API",
        "version": settings.VERSION,
        "docs": "/docs",
        "health": f"{settings.API_V1_STR}/health",
        "api_info": f"{settings.API_V1_STR}/info"
    }


# Health check endpoint at root level (for load balancers)
@app.get("/health", tags=["Health"])
async def root_health():
    """Simple health check at root level."""
    return {"status": "ok"}


if __name__ == "__main__":
    # Run the application
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=False,  # Set to True for development
        log_level="info"
    )
app/models/__init__.py ADDED
@@ -0,0 +1 @@
# Pydantic models and YOLOv5 wrapper
app/models/inference.py ADDED
@@ -0,0 +1,123 @@
"""
Pydantic models for API requests and responses.
"""

from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, validator
import base64


class BoundingBox(BaseModel):
    """Bounding box coordinates."""
    x: float = Field(..., description="X coordinate of top-left corner")
    y: float = Field(..., description="Y coordinate of top-left corner")
    width: float = Field(..., description="Width of bounding box")
    height: float = Field(..., description="Height of bounding box")


class Detection(BaseModel):
    """Single detection result."""
    class_id: int = Field(..., description="Class ID of detected species")
    class_name: str = Field(..., description="Name of detected marine species")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Detection confidence score")
    bbox: BoundingBox = Field(..., description="Bounding box coordinates")


class ModelInfo(BaseModel):
    """Model information."""
    model_name: str = Field(..., description="Name of the model")
    total_classes: int = Field(..., description="Total number of species classes")
    device: str = Field(..., description="Device used for inference")
    model_path: str = Field(..., description="Path to model file")


class InferenceRequest(BaseModel):
    """Request model for marine species detection."""
    image: str = Field(..., description="Base64 encoded image data")
    confidence_threshold: float = Field(
        default=0.25,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for detections"
    )
    iou_threshold: float = Field(
        default=0.45,
        ge=0.0,
        le=1.0,
        description="IoU threshold for non-maximum suppression"
    )
    image_size: int = Field(
        default=720,
        ge=320,
        le=1280,
        description="Input image size for inference"
    )
    return_annotated_image: bool = Field(
        default=True,
        description="Whether to return annotated image with detections"
    )
    classes: Optional[List[int]] = Field(
        default=None,
        description="List of class IDs to filter (None for all classes)"
    )

    @validator('image')
    def validate_image(cls, v):
        """Validate base64 image data."""
        try:
            # Try to decode base64 to ensure it's valid
            base64.b64decode(v)
            return v
        except Exception:
            raise ValueError("Invalid base64 image data")


class InferenceResponse(BaseModel):
    """Response model for marine species detection."""
    detections: List[Detection] = Field(..., description="List of detected marine species")
    annotated_image: Optional[str] = Field(
        default=None,
        description="Base64 encoded annotated image (if requested)"
    )
    processing_time: float = Field(..., description="Processing time in seconds")
    model_info: ModelInfo = Field(..., description="Information about the model used")
    image_dimensions: Dict[str, int] = Field(
        ...,
        description="Original image dimensions (width, height)"
    )


class SpeciesInfo(BaseModel):
    """Information about a marine species."""
    class_id: int = Field(..., description="Class ID")
93
+ class_name: str = Field(..., description="Species name")
94
+
95
+
96
+ class SpeciesListResponse(BaseModel):
97
+ """Response model for species list endpoint."""
98
+ species: List[SpeciesInfo] = Field(..., description="List of all supported marine species")
99
+ total_count: int = Field(..., description="Total number of species")
100
+
101
+
102
+ class HealthResponse(BaseModel):
103
+ """Response model for health check."""
104
+ status: str = Field(..., description="API status")
105
+ model_loaded: bool = Field(..., description="Whether the model is loaded")
106
+ model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
107
+ timestamp: str = Field(..., description="Response timestamp")
108
+
109
+
110
+ class ErrorResponse(BaseModel):
111
+ """Error response model."""
112
+ error: str = Field(..., description="Error type")
113
+ message: str = Field(..., description="Error message")
114
+ details: Optional[Dict[str, Any]] = Field(default=None, description="Additional error details")
115
+
116
+
117
+ class APIInfo(BaseModel):
118
+ """API information response."""
119
+ name: str = Field(..., description="API name")
120
+ version: str = Field(..., description="API version")
121
+ description: str = Field(..., description="API description")
122
+ model_info: ModelInfo = Field(..., description="Model information")
123
+ endpoints: List[str] = Field(..., description="Available endpoints")
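The `image` field's validator only checks that the payload decodes as base64 (it does not inspect the image itself). A minimal stand-alone sketch of the same check, stdlib only with no pydantic, which a client could use to pre-validate payloads before calling the API:

```python
import base64


def check_base64_image(v: str) -> str:
    """Mirror of the InferenceRequest validator: decode to confirm base64."""
    try:
        base64.b64decode(v)
        return v
    except Exception:
        raise ValueError("Invalid base64 image data")


# Encode some bytes the way a client would before sending a request.
payload = base64.b64encode(b"\x89PNG fake image bytes").decode("ascii")
print(check_base64_image(payload) == payload)  # True
```

Note that `b64decode` without `validate=True` is lenient, so this only rejects grossly malformed input; real image validation happens later when PIL opens the bytes.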
app/models/yolo.py ADDED
@@ -0,0 +1,186 @@
+ """
2
+ YOLOv5 model wrapper adapted from the original Gradio implementation.
3
+ Compatible with the existing marina-benthic-33k.pt model.
4
+ """
5
+
6
+ import torch
7
+ import yolov5
8
+ import numpy as np
9
+ from typing import Optional, List, Union, Dict, Any
10
+ from pathlib import Path
11
+
12
+ from app.core.config import settings
13
+ from app.core.logging import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
+ class MarineSpeciesYOLO:
19
+ """
20
+ Wrapper class for loading and running the marine species YOLOv5 model.
21
+ Adapted from the original inference.py to work with FastAPI.
22
+ """
23
+
24
+ def __init__(self, model_path: str, device: Optional[str] = None):
25
+ """
26
+ Initialize the YOLO model.
27
+
28
+ Args:
29
+ model_path: Path to the YOLOv5 model file
30
+ device: Device to run inference on ('cpu', 'cuda', etc.)
31
+ """
32
+ self.model_path = model_path
33
+ self.device = device or self._get_device()
34
+ self.model = None
35
+ self._class_names = None
36
+
37
+ logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
38
+ self._load_model()
39
+
40
+ def _get_device(self) -> str:
41
+ """Auto-detect the best available device."""
42
+ if torch.cuda.is_available():
43
+ return "cuda"
44
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
45
+ return "mps" # Apple Silicon
46
+ else:
47
+ return "cpu"
48
+
49
+ def _load_model(self) -> None:
50
+ """Load the YOLOv5 model."""
51
+ try:
52
+ if not Path(self.model_path).exists():
53
+ raise FileNotFoundError(f"Model file not found: {self.model_path}")
54
+
55
+ logger.info(f"Loading YOLOv5 model from: {self.model_path}")
56
+ self.model = yolov5.load(self.model_path, device=self.device)
57
+
58
+ # Get class names if available
59
+ if hasattr(self.model, 'names'):
60
+ self._class_names = self.model.names
61
+ logger.info(f"Loaded model with {len(self._class_names)} classes")
62
+
63
+ logger.info("YOLOv5 model loaded successfully")
64
+
65
+ except Exception as e:
66
+ logger.error(f"Failed to load YOLOv5 model: {str(e)}")
67
+ raise
68
+
69
+ def predict(
70
+ self,
71
+ image: Union[str, np.ndarray],
72
+ conf_threshold: float = 0.25,
73
+ iou_threshold: float = 0.45,
74
+ image_size: int = 720,
75
+ classes: Optional[List[int]] = None
76
+ ) -> torch.Tensor:
77
+ """
78
+ Run inference on an image.
79
+
80
+ Args:
81
+ image: Input image (file path or numpy array)
82
+ conf_threshold: Confidence threshold for detections
83
+ iou_threshold: IoU threshold for NMS
84
+ image_size: Input image size for inference
85
+ classes: List of class IDs to filter (None for all classes)
86
+
87
+ Returns:
88
+ YOLOv5 detection results
89
+ """
90
+ if self.model is None:
91
+ raise RuntimeError("Model not loaded")
92
+
93
+ # Set model parameters
94
+ self.model.conf = conf_threshold
95
+ self.model.iou = iou_threshold
96
+
97
+ if classes is not None:
98
+ self.model.classes = classes
99
+
100
+ # Run inference
101
+ try:
102
+ detections = self.model(image, size=image_size)
103
+ return detections
104
+ except Exception as e:
105
+ logger.error(f"Inference failed: {str(e)}")
106
+ raise
107
+
108
+ def get_class_names(self) -> Optional[Dict[int, str]]:
109
+ """Get the class names mapping."""
110
+ return self._class_names
111
+
112
+ def get_model_info(self) -> Dict[str, Any]:
113
+ """Get model information."""
114
+ return {
115
+ "model_path": self.model_path,
116
+ "device": self.device,
117
+ "num_classes": len(self._class_names) if self._class_names else None,
118
+ "class_names": self._class_names
119
+ }
120
+
121
+ def warmup(self, image_size: int = 720) -> None:
122
+ """
123
+ Warm up the model with a dummy inference.
124
+
125
+ Args:
126
+ image_size: Size for warmup inference
127
+ """
128
+ if self.model is None:
129
+ return
130
+
131
+ try:
132
+ logger.info("Warming up model...")
133
+ # Create a dummy image
134
+ dummy_image = np.random.randint(0, 255, (image_size, image_size, 3), dtype=np.uint8)
135
+ self.predict(dummy_image, conf_threshold=0.1)
136
+ logger.info("Model warmup completed")
137
+ except Exception as e:
138
+ logger.warning(f"Model warmup failed: {str(e)}")
139
+
140
+
141
+ # Global model instance (singleton pattern)
142
+ _model_instance: Optional[MarineSpeciesYOLO] = None
143
+
144
+
145
+ def get_model() -> MarineSpeciesYOLO:
146
+ """
147
+ Get the global model instance (singleton pattern).
148
+
149
+ Returns:
150
+ MarineSpeciesYOLO instance
151
+ """
152
+ global _model_instance
153
+
154
+ if _model_instance is None:
155
+ _model_instance = MarineSpeciesYOLO(
156
+ model_path=settings.MODEL_PATH,
157
+ device=settings.DEVICE
158
+ )
159
+
160
+ # Warm up the model if enabled
161
+ if settings.ENABLE_MODEL_WARMUP:
162
+ _model_instance.warmup()
163
+
164
+ return _model_instance
165
+
166
+
167
+ def load_class_names(names_file: str) -> Dict[int, str]:
168
+ """
169
+ Load class names from a .names file.
170
+
171
+ Args:
172
+ names_file: Path to the .names file
173
+
174
+ Returns:
175
+ Dictionary mapping class IDs to names
176
+ """
177
+ class_names = {}
178
+ try:
179
+ with open(names_file, 'r') as f:
180
+ for idx, line in enumerate(f):
181
+ class_names[idx] = line.strip()
182
+ logger.info(f"Loaded {len(class_names)} class names from {names_file}")
183
+ except Exception as e:
184
+ logger.error(f"Failed to load class names: {str(e)}")
185
+
186
+ return class_names
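`load_class_names` maps the zero-based line index of a `.names` file to the species name on that line. A self-contained check of that parsing logic against a temporary file (the species names here are made up for illustration):

```python
import os
import tempfile

# Write a tiny .names file: one class name per line, class ID = line number.
names = ["Acropora", "Porites", "Fungia"]
with tempfile.NamedTemporaryFile("w", suffix=".names", delete=False) as f:
    f.write("\n".join(names))
    path = f.name

# Same enumerate-and-strip loop as load_class_names above.
class_names = {}
with open(path, "r") as f:
    for idx, line in enumerate(f):
        class_names[idx] = line.strip()
os.unlink(path)

print(class_names)  # {0: 'Acropora', 1: 'Porites', 2: 'Fungia'}
```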
app/services/__init__.py ADDED
@@ -0,0 +1 @@
+ # Services layer for business logic
app/services/inference_service.py ADDED
@@ -0,0 +1,201 @@
+ """
2
+ Inference service for marine species detection.
3
+ """
4
+
5
+ import time
6
+ import base64
7
+ import io
8
+ from typing import List, Optional, Dict, Tuple
9
+ import numpy as np
10
+ from PIL import Image
11
+ import cv2
12
+
13
+ from app.core.config import settings
14
+ from app.core.logging import get_logger
15
+ from app.models.inference import Detection, BoundingBox, InferenceResponse, ModelInfo
16
+ from app.services.model_service import model_service
17
+ from app.utils.image_processing import decode_base64_image, encode_image_to_base64
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
+ class InferenceService:
23
+ """Service for running marine species detection inference."""
24
+
25
+ def __init__(self):
26
+ self.model_service = model_service
27
+
28
+ async def detect_species(
29
+ self,
30
+ image_data: str,
31
+ confidence_threshold: float = 0.25,
32
+ iou_threshold: float = 0.45,
33
+ image_size: int = 720,
34
+ return_annotated_image: bool = True,
35
+ classes: Optional[List[int]] = None
36
+ ) -> InferenceResponse:
37
+ """
38
+ Detect marine species in an image.
39
+
40
+ Args:
41
+ image_data: Base64 encoded image
42
+ confidence_threshold: Confidence threshold for detections
43
+ iou_threshold: IoU threshold for NMS
44
+ image_size: Input image size for inference
45
+ return_annotated_image: Whether to return annotated image
46
+ classes: List of class IDs to filter
47
+
48
+ Returns:
49
+ InferenceResponse with detection results
50
+ """
51
+ start_time = time.time()
52
+
53
+ try:
54
+ # Decode the image
55
+ image, original_dims = decode_base64_image(image_data)
56
+ logger.info(f"Processing image with dimensions: {original_dims}")
57
+
58
+ # Get the model
59
+ model = self.model_service.get_model()
60
+
61
+ # Run inference
62
+ predictions = model.predict(
63
+ image=image,
64
+ conf_threshold=confidence_threshold,
65
+ iou_threshold=iou_threshold,
66
+ image_size=image_size,
67
+ classes=classes
68
+ )
69
+
70
+ # Process predictions
71
+ detections = self._process_predictions(predictions)
72
+
73
+ # Generate annotated image if requested
74
+ annotated_image_b64 = None
75
+ if return_annotated_image and detections:
76
+ annotated_image = self._create_annotated_image(image, predictions)
77
+ annotated_image_b64 = encode_image_to_base64(annotated_image)
78
+
79
+ # Get model info
80
+ model_info_dict = self.model_service.get_model_info()
81
+ model_info = ModelInfo(**model_info_dict)
82
+
83
+ processing_time = time.time() - start_time
84
+
85
+ logger.info(f"Inference completed in {processing_time:.3f}s, found {len(detections)} detections")
86
+
87
+ return InferenceResponse(
88
+ detections=detections,
89
+ annotated_image=annotated_image_b64,
90
+ processing_time=processing_time,
91
+ model_info=model_info,
92
+ image_dimensions={"width": original_dims[0], "height": original_dims[1]}
93
+ )
94
+
95
+ except Exception as e:
96
+ logger.error(f"Inference failed: {str(e)}")
97
+ raise
98
+
99
+ def _process_predictions(self, predictions) -> List[Detection]:
100
+ """
101
+ Process YOLOv5 predictions into Detection objects.
102
+
103
+ Args:
104
+ predictions: YOLOv5 prediction results
105
+
106
+ Returns:
107
+ List of Detection objects
108
+ """
109
+ detections = []
110
+ class_names = self.model_service.get_class_names()
111
+
112
+ try:
113
+ # Get predictions as pandas DataFrame
114
+ pred_df = predictions.pandas().xyxy[0]
115
+
116
+ for _, row in pred_df.iterrows():
117
+ # Extract bounding box coordinates
118
+ x1, y1, x2, y2 = row['xmin'], row['ymin'], row['xmax'], row['ymax']
119
+ width = x2 - x1
120
+ height = y2 - y1
121
+
122
+ # Get class information
123
+ class_id = int(row['class'])
124
+ confidence = float(row['confidence'])
125
+
126
+ # Get class name
127
+ if class_names and class_id in class_names:
128
+ class_name = class_names[class_id]
129
+ else:
130
+ class_name = f"class_{class_id}"
131
+
132
+ # Create detection object
133
+ detection = Detection(
134
+ class_id=class_id,
135
+ class_name=class_name,
136
+ confidence=confidence,
137
+ bbox=BoundingBox(
138
+ x=float(x1),
139
+ y=float(y1),
140
+ width=float(width),
141
+ height=float(height)
142
+ )
143
+ )
144
+
145
+ detections.append(detection)
146
+
147
+ except Exception as e:
148
+ logger.error(f"Failed to process predictions: {str(e)}")
149
+ raise
150
+
151
+ return detections
152
+
153
+ def _create_annotated_image(self, original_image: np.ndarray, predictions) -> np.ndarray:
154
+ """
155
+ Create an annotated image with detection boxes and labels.
156
+
157
+ Args:
158
+ original_image: Original input image
159
+ predictions: YOLOv5 prediction results
160
+
161
+ Returns:
162
+ Annotated image as numpy array
163
+ """
164
+ try:
165
+ # Use YOLOv5's built-in rendering
166
+ rendered_imgs = predictions.render()
167
+ if rendered_imgs and len(rendered_imgs) > 0:
168
+ return rendered_imgs[0]
169
+ else:
170
+ # Fallback: return original image if rendering fails
171
+ return original_image
172
+
173
+ except Exception as e:
174
+ logger.error(f"Failed to create annotated image: {str(e)}")
175
+ # Return original image as fallback
176
+ return original_image
177
+
178
+ async def get_supported_species(self) -> List[Dict]:
179
+ """
180
+ Get list of all supported marine species.
181
+
182
+ Returns:
183
+ List of species information
184
+ """
185
+ class_names = self.model_service.get_class_names()
186
+
187
+ if not class_names:
188
+ return []
189
+
190
+ species_list = []
191
+ for class_id, class_name in class_names.items():
192
+ species_list.append({
193
+ "class_id": class_id,
194
+ "class_name": class_name
195
+ })
196
+
197
+ return sorted(species_list, key=lambda x: x["class_name"])
198
+
199
+
200
+ # Global service instance
201
+ inference_service = InferenceService()
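`_process_predictions` converts YOLOv5's corner-coordinate boxes (`xmin, ymin, xmax, ymax`) into the API's top-left-plus-size `BoundingBox` format. That conversion, extracted as a pure function for clarity:

```python
def xyxy_to_xywh(x1: float, y1: float, x2: float, y2: float):
    """Corner coordinates (xmin, ymin, xmax, ymax) -> (x, y, width, height),
    as done when building BoundingBox objects in _process_predictions."""
    return x1, y1, x2 - x1, y2 - y1


print(xyxy_to_xywh(10.0, 20.0, 110.0, 70.0))  # (10.0, 20.0, 100.0, 50.0)
```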
app/services/model_service.py ADDED
@@ -0,0 +1,166 @@
+ """
2
+ Model service for managing YOLOv5 model lifecycle and operations.
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Dict, Optional
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from app.core.config import settings
11
+ from app.core.logging import get_logger
12
+ from app.models.yolo import MarineSpeciesYOLO, get_model
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ class ModelService:
18
+ """Service for managing the marine species detection model."""
19
+
20
+ def __init__(self):
21
+ self._model: Optional[MarineSpeciesYOLO] = None
22
+ self._class_names: Optional[Dict[int, str]] = None
23
+
24
+ async def ensure_model_available(self) -> None:
25
+ """
26
+ Ensure the model is downloaded and available.
27
+ Downloads from HuggingFace Hub if not present locally.
28
+ """
29
+ model_path = Path(settings.MODEL_PATH)
30
+
31
+ # Check if model exists locally
32
+ if not model_path.exists():
33
+ logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
34
+ await self._download_model()
35
+
36
+ # Load class names if available
37
+ await self._load_class_names()
38
+
39
+ async def _download_model(self) -> None:
40
+ """Download model from HuggingFace Hub."""
41
+ try:
42
+ # Create models directory if it doesn't exist
43
+ model_dir = Path(settings.MODEL_PATH).parent
44
+ model_dir.mkdir(parents=True, exist_ok=True)
45
+
46
+ # Download the model file
47
+ logger.info(f"Downloading model from {settings.HUGGINGFACE_REPO}")
48
+
49
+ # Download the .pt model file
50
+ model_filename = f"{settings.MODEL_NAME}.pt"
51
+ downloaded_path = hf_hub_download(
52
+ repo_id=settings.HUGGINGFACE_REPO,
53
+ filename=model_filename,
54
+ cache_dir=str(model_dir.parent / ".cache"),
55
+ local_dir=str(model_dir),
56
+ local_dir_use_symlinks=False
57
+ )
58
+
59
+ logger.info(f"Model downloaded successfully to: {downloaded_path}")
60
+
61
+ # Also download the .names file if available
62
+ try:
63
+ names_filename = f"{settings.MODEL_NAME}.names"
64
+ names_path = hf_hub_download(
65
+ repo_id=settings.HUGGINGFACE_REPO,
66
+ filename=names_filename,
67
+ cache_dir=str(model_dir.parent / ".cache"),
68
+ local_dir=str(model_dir),
69
+ local_dir_use_symlinks=False
70
+ )
71
+ logger.info(f"Class names file downloaded to: {names_path}")
72
+ except Exception as e:
73
+ logger.warning(f"Could not download .names file: {str(e)}")
74
+
75
+ except Exception as e:
76
+ logger.error(f"Failed to download model: {str(e)}")
77
+ raise RuntimeError(f"Model download failed: {str(e)}")
78
+
79
+ async def _load_class_names(self) -> None:
80
+ """Load class names from .names file."""
81
+ names_file = Path(settings.MODEL_PATH).with_suffix('.names')
82
+
83
+ if names_file.exists():
84
+ try:
85
+ class_names = {}
86
+ with open(names_file, 'r') as f:
87
+ for idx, line in enumerate(f):
88
+ class_names[idx] = line.strip()
89
+
90
+ self._class_names = class_names
91
+ logger.info(f"Loaded {len(class_names)} class names")
92
+ except Exception as e:
93
+ logger.error(f"Failed to load class names: {str(e)}")
94
+ else:
95
+ logger.warning(f"Class names file not found: {names_file}")
96
+
97
+ def get_model(self) -> MarineSpeciesYOLO:
98
+ """
99
+ Get the model instance.
100
+
101
+ Returns:
102
+ MarineSpeciesYOLO instance
103
+ """
104
+ if self._model is None:
105
+ self._model = get_model()
106
+ return self._model
107
+
108
+ def get_class_names(self) -> Optional[Dict[int, str]]:
109
+ """
110
+ Get class names mapping.
111
+
112
+ Returns:
113
+ Dictionary mapping class IDs to names
114
+ """
115
+ if self._class_names is None:
116
+ # Try to get from model
117
+ model = self.get_model()
118
+ self._class_names = model.get_class_names()
119
+
120
+ return self._class_names
121
+
122
+ def get_model_info(self) -> Dict:
123
+ """
124
+ Get comprehensive model information.
125
+
126
+ Returns:
127
+ Dictionary with model information
128
+ """
129
+ model = self.get_model()
130
+ class_names = self.get_class_names()
131
+
132
+ return {
133
+ "model_name": settings.MODEL_NAME,
134
+ "total_classes": len(class_names) if class_names else 0,
135
+ "device": model.device,
136
+ "model_path": settings.MODEL_PATH,
137
+ "huggingface_repo": settings.HUGGINGFACE_REPO
138
+ }
139
+
140
+ async def health_check(self) -> Dict:
141
+ """
142
+ Perform a health check on the model.
143
+
144
+ Returns:
145
+ Dictionary with health status
146
+ """
147
+ try:
148
+ model = self.get_model()
149
+ model_info = self.get_model_info()
150
+
151
+ return {
152
+ "status": "healthy",
153
+ "model_loaded": True,
154
+ "model_info": model_info
155
+ }
156
+ except Exception as e:
157
+ logger.error(f"Model health check failed: {str(e)}")
158
+ return {
159
+ "status": "unhealthy",
160
+ "model_loaded": False,
161
+ "error": str(e)
162
+ }
163
+
164
+
165
+ # Global service instance
166
+ model_service = ModelService()
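`_load_class_names` derives the `.names` path from the configured model path via `Path.with_suffix`, so the two files must sit side by side with the same stem. A quick sketch of that derivation (the directory here is illustrative, not the project's actual `MODEL_PATH`):

```python
from pathlib import Path

# How the service derives the class-names path from the model path.
model_path = Path("models/marina-benthic-33k.pt")  # illustrative path
names_file = model_path.with_suffix(".names")

print(names_file.name)  # marina-benthic-33k.names
```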
app/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # Utility functions
app/utils/image_processing.py ADDED
@@ -0,0 +1,193 @@
+ """
2
+ Image processing utilities for the FastAPI application.
3
+ """
4
+
5
+ import base64
6
+ import io
7
+ from typing import Tuple
8
+ import numpy as np
9
+ from PIL import Image
10
+ import cv2
11
+
12
+ from app.core.config import settings
13
+ from app.core.logging import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
+ def decode_base64_image(image_data: str) -> Tuple[np.ndarray, Tuple[int, int]]:
19
+ """
20
+ Decode base64 image data to numpy array.
21
+
22
+ Args:
23
+ image_data: Base64 encoded image string
24
+
25
+ Returns:
26
+ Tuple of (image_array, (width, height))
27
+ """
28
+ try:
29
+ # Remove data URL prefix if present
30
+ if image_data.startswith('data:image'):
31
+ image_data = image_data.split(',')[1]
32
+
33
+ # Decode base64
34
+ image_bytes = base64.b64decode(image_data)
35
+
36
+ # Open with PIL
37
+ pil_image = Image.open(io.BytesIO(image_bytes))
38
+
39
+ # Convert to RGB if necessary
40
+ if pil_image.mode != 'RGB':
41
+ pil_image = pil_image.convert('RGB')
42
+
43
+ # Get original dimensions
44
+ original_dims = pil_image.size # (width, height)
45
+
46
+ # Convert to numpy array
47
+ image_array = np.array(pil_image)
48
+
49
+ logger.debug(f"Decoded image with shape: {image_array.shape}")
50
+
51
+ return image_array, original_dims
52
+
53
+ except Exception as e:
54
+ logger.error(f"Failed to decode base64 image: {str(e)}")
55
+ raise ValueError(f"Invalid image data: {str(e)}")
56
+
57
+
58
+ def encode_image_to_base64(image: np.ndarray, format: str = "JPEG", quality: int = 95) -> str:
59
+ """
60
+ Encode numpy array image to base64 string.
61
+
62
+ Args:
63
+ image: Image as numpy array
64
+ format: Image format (JPEG, PNG, etc.)
65
+ quality: JPEG quality (1-100)
66
+
67
+ Returns:
68
+ Base64 encoded image string
69
+ """
70
+ try:
71
+ # Convert numpy array to PIL Image
72
+ if image.dtype != np.uint8:
73
+ image = (image * 255).astype(np.uint8)
74
+
75
+ pil_image = Image.fromarray(image)
76
+
77
+ # Save to bytes buffer
78
+ buffer = io.BytesIO()
79
+ save_kwargs = {"format": format}
80
+
81
+ if format.upper() == "JPEG":
82
+ save_kwargs["quality"] = quality
83
+ save_kwargs["optimize"] = True
84
+
85
+ pil_image.save(buffer, **save_kwargs)
86
+
87
+ # Encode to base64
88
+ image_bytes = buffer.getvalue()
89
+ base64_string = base64.b64encode(image_bytes).decode('utf-8')
90
+
91
+ return base64_string
92
+
93
+ except Exception as e:
94
+ logger.error(f"Failed to encode image to base64: {str(e)}")
95
+ raise ValueError(f"Image encoding failed: {str(e)}")
96
+
97
+
98
+ def validate_image_size(image: np.ndarray) -> bool:
99
+ """
100
+ Validate image dimensions.
101
+
102
+ Args:
103
+ image: Image as numpy array
104
+
105
+ Returns:
106
+ True if image size is valid
107
+ """
108
+ height, width = image.shape[:2]
109
+
110
+ # Check minimum and maximum dimensions
111
+ min_dim = min(width, height)
112
+ max_dim = max(width, height)
113
+
114
+ if min_dim < 32: # Too small
115
+ return False
116
+
117
+ if max_dim > 4096: # Too large
118
+ return False
119
+
120
+ return True
121
+
122
+
123
+ def resize_image_if_needed(image: np.ndarray, max_size: int = 1280) -> np.ndarray:
124
+ """
125
+ Resize image if it's too large while maintaining aspect ratio.
126
+
127
+ Args:
128
+ image: Image as numpy array
129
+ max_size: Maximum dimension size
130
+
131
+ Returns:
132
+ Resized image
133
+ """
134
+ height, width = image.shape[:2]
135
+
136
+ if max(height, width) <= max_size:
137
+ return image
138
+
139
+ # Calculate new dimensions
140
+ if width > height:
141
+ new_width = max_size
142
+ new_height = int(height * (max_size / width))
143
+ else:
144
+ new_height = max_size
145
+ new_width = int(width * (max_size / height))
146
+
147
+ # Resize using PIL for better quality
148
+ pil_image = Image.fromarray(image)
149
+ resized_pil = pil_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
150
+
151
+ return np.array(resized_pil)
152
+
153
+
154
+ def validate_image_format(image_bytes: bytes) -> bool:
155
+ """
156
+ Validate if the image format is supported.
157
+
158
+ Args:
159
+ image_bytes: Raw image bytes
160
+
161
+ Returns:
162
+ True if format is supported
163
+ """
164
+ try:
165
+ with Image.open(io.BytesIO(image_bytes)) as img:
166
+ # Check if format is in allowed extensions
167
+ format_lower = img.format.lower() if img.format else ""
168
+ allowed_formats = {"jpeg", "jpg", "png", "bmp", "tiff", "webp"}
169
+ return format_lower in allowed_formats
170
+ except Exception:
171
+ return False
172
+
173
+
174
+ def get_image_info(image: np.ndarray) -> dict:
175
+ """
176
+ Get information about an image.
177
+
178
+ Args:
179
+ image: Image as numpy array
180
+
181
+ Returns:
182
+ Dictionary with image information
183
+ """
184
+ height, width = image.shape[:2]
185
+ channels = image.shape[2] if len(image.shape) > 2 else 1
186
+
187
+ return {
188
+ "width": width,
189
+ "height": height,
190
+ "channels": channels,
191
+ "dtype": str(image.dtype),
192
+ "size_mb": image.nbytes / (1024 * 1024)
193
+ }
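The aspect-ratio arithmetic inside `resize_image_if_needed` can be checked without any image data by extracting just the dimension calculation:

```python
def target_dims(width: int, height: int, max_size: int = 1280):
    """Aspect-preserving target size, same math as resize_image_if_needed."""
    if max(width, height) <= max_size:
        return width, height
    if width > height:
        return max_size, int(height * (max_size / width))
    return int(width * (max_size / height)), max_size


print(target_dims(4000, 3000))  # (1280, 960)
print(target_dims(800, 600))    # (800, 600) -- already small enough
```

Note that `int()` truncates, so the shorter side can be off by a pixel relative to the exact ratio; that matches the behavior of the function above.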
app/utils/model_utils.py ADDED
@@ -0,0 +1,227 @@
+ """
2
+ Model utilities for downloading and managing the marine species model.
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ from pathlib import Path
8
+ from typing import Optional, Dict, Any
9
+ from huggingface_hub import hf_hub_download, list_repo_files
10
+
11
+ from app.core.config import settings
12
+ from app.core.logging import get_logger
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
+ def download_model_from_hf(
18
+ repo_id: str,
19
+ model_filename: str,
20
+ local_dir: str,
21
+ force_download: bool = False
22
+ ) -> str:
23
+ """
24
+ Download model from HuggingFace Hub.
25
+
26
+ Args:
27
+ repo_id: HuggingFace repository ID
28
+ model_filename: Name of the model file
29
+ local_dir: Local directory to save the model
30
+ force_download: Whether to force re-download if file exists
31
+
32
+ Returns:
33
+ Path to the downloaded model file
34
+ """
35
+ try:
36
+ # Create local directory if it doesn't exist
37
+ Path(local_dir).mkdir(parents=True, exist_ok=True)
38
+
39
+ local_path = Path(local_dir) / model_filename
40
+
41
+ # Check if file already exists and force_download is False
42
+ if local_path.exists() and not force_download:
43
+ logger.info(f"Model already exists at {local_path}")
44
+ return str(local_path)
45
+
46
+ logger.info(f"Downloading {model_filename} from {repo_id}...")
47
+
48
+ downloaded_path = hf_hub_download(
49
+ repo_id=repo_id,
50
+ filename=model_filename,
51
+ local_dir=local_dir,
52
+ local_dir_use_symlinks=False,
53
+ force_download=force_download
54
+ )
55
+
56
+ logger.info(f"Model downloaded successfully to: {downloaded_path}")
57
+ return downloaded_path
58
+
59
+ except Exception as e:
60
+ logger.error(f"Failed to download model: {str(e)}")
61
+ raise
62
+
63
+
64
+ def list_available_files(repo_id: str) -> list:
65
+ """
66
+ List all available files in a HuggingFace repository.
67
+
68
+ Args:
69
+ repo_id: HuggingFace repository ID
70
+
71
+ Returns:
72
+ List of available files
73
+ """
74
+ try:
75
+ files = list_repo_files(repo_id)
76
+ return files
77
+ except Exception as e:
78
+ logger.error(f"Failed to list repository files: {str(e)}")
79
+ return []
80
+
81
+
82
+ def verify_model_file(model_path: str) -> bool:
83
+ """
84
+ Verify that a model file exists and is valid.
85
+
86
+ Args:
87
+ model_path: Path to the model file
88
+
89
+ Returns:
90
+ True if model file is valid
91
+ """
92
+ try:
93
+ path = Path(model_path)
94
+
95
+ # Check if file exists
96
+ if not path.exists():
97
+ logger.error(f"Model file does not exist: {model_path}")
98
+ return False
99
+
100
+ # Check file size (should be > 1MB for a real model)
101
+ file_size = path.stat().st_size
102
+ if file_size < 1024 * 1024: # 1MB
103
+ logger.warning(f"Model file seems too small: {file_size} bytes")
104
+ return False
105
+
106
+ # Check file extension
107
+ if not path.suffix.lower() in ['.pt', '.pth']:
+             logger.warning(f"Unexpected model file extension: {path.suffix}")
+
+         logger.info(f"Model file verified: {model_path} ({file_size / (1024*1024):.1f} MB)")
+         return True
+
+     except Exception as e:
+         logger.error(f"Failed to verify model file: {str(e)}")
+         return False
+
+
+ def get_model_info(model_path: str) -> Dict[str, Any]:
+     """
+     Get information about a model file.
+
+     Args:
+         model_path: Path to the model file
+
+     Returns:
+         Dictionary with model information
+     """
+     info = {
+         "path": model_path,
+         "exists": False,
+         "size_mb": 0,
+         "size_bytes": 0
+     }
+
+     try:
+         path = Path(model_path)
+
+         if path.exists():
+             info["exists"] = True
+             size_bytes = path.stat().st_size
+             info["size_bytes"] = size_bytes
+             info["size_mb"] = size_bytes / (1024 * 1024)
+             info["modified_time"] = path.stat().st_mtime
+
+     except Exception as e:
+         logger.error(f"Failed to get model info: {str(e)}")
+
+     return info
+
+
+ def cleanup_model_cache(cache_dir: Optional[str] = None) -> None:
+     """
+     Clean up the model cache directory.
+
+     Args:
+         cache_dir: Cache directory to clean (uses the default if None)
+     """
+     try:
+         if cache_dir is None:
+             cache_dir = Path.home() / ".cache" / "huggingface"
+
+         cache_path = Path(cache_dir)
+
+         if cache_path.exists():
+             logger.info(f"Cleaning up cache directory: {cache_path}")
+             shutil.rmtree(cache_path)
+             logger.info("Cache cleanup completed")
+         else:
+             logger.info("Cache directory does not exist")
+
+     except Exception as e:
+         logger.error(f"Failed to cleanup cache: {str(e)}")
+
+
+ def setup_model_directory() -> str:
+     """
+     Set up the model directory and ensure it exists.
+
+     Returns:
+         Path to the model directory
+     """
+     model_dir = Path(settings.MODEL_PATH).parent
+     model_dir.mkdir(parents=True, exist_ok=True)
+
+     logger.info(f"Model directory setup: {model_dir}")
+     return str(model_dir)
+
+
+ if __name__ == "__main__":
+     # Command line utility for model management
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Model management utility")
+     parser.add_argument("--download", action="store_true", help="Download model from HuggingFace")
+     parser.add_argument("--verify", action="store_true", help="Verify model file")
+     parser.add_argument("--info", action="store_true", help="Show model information")
+     parser.add_argument("--list-files", action="store_true", help="List available files in HF repo")
+     parser.add_argument("--cleanup-cache", action="store_true", help="Cleanup model cache")
+     parser.add_argument("--force", action="store_true", help="Force download even if file exists")
+
+     args = parser.parse_args()
+
+     if args.download:
+         setup_model_directory()
+         download_model_from_hf(
+             repo_id=settings.HUGGINGFACE_REPO,
+             model_filename=f"{settings.MODEL_NAME}.pt",
+             local_dir=str(Path(settings.MODEL_PATH).parent),
+             force_download=args.force
+         )
+
+     if args.verify:
+         is_valid = verify_model_file(settings.MODEL_PATH)
+         print(f"Model valid: {is_valid}")
+
+     if args.info:
+         info = get_model_info(settings.MODEL_PATH)
+         print(f"Model info: {info}")
+
+     if args.list_files:
+         files = list_available_files(settings.HUGGINGFACE_REPO)
+         print(f"Available files in {settings.HUGGINGFACE_REPO}:")
+         for file in files:
+             print(f"  - {file}")
+
+     if args.cleanup_cache:
+         cleanup_model_cache()
app/utils/performance.py ADDED
@@ -0,0 +1,237 @@
+ """
+ Performance monitoring and optimization utilities.
+ """
+
+ import time
+ import psutil
+ import functools
+ from typing import Dict, Any, Callable, Optional
+ from contextlib import contextmanager
+
+ from app.core.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class PerformanceMonitor:
+     """Performance monitoring utility."""
+
+     def __init__(self):
+         self.metrics = {
+             "requests_total": 0,
+             "requests_successful": 0,
+             "requests_failed": 0,
+             "total_processing_time": 0.0,
+             "average_processing_time": 0.0,
+             "min_processing_time": float('inf'),
+             "max_processing_time": 0.0
+         }
+
+     def record_request(self, processing_time: float, success: bool = True):
+         """Record a request's performance metrics."""
+         self.metrics["requests_total"] += 1
+
+         if success:
+             self.metrics["requests_successful"] += 1
+         else:
+             self.metrics["requests_failed"] += 1
+
+         self.metrics["total_processing_time"] += processing_time
+         self.metrics["average_processing_time"] = (
+             self.metrics["total_processing_time"] / self.metrics["requests_total"]
+         )
+
+         if processing_time < self.metrics["min_processing_time"]:
+             self.metrics["min_processing_time"] = processing_time
+
+         if processing_time > self.metrics["max_processing_time"]:
+             self.metrics["max_processing_time"] = processing_time
+
+     def get_metrics(self) -> Dict[str, Any]:
+         """Get current performance metrics."""
+         metrics = self.metrics.copy()
+
+         # Add system metrics
+         try:
+             metrics.update({
+                 "cpu_percent": psutil.cpu_percent(),
+                 "memory_percent": psutil.virtual_memory().percent,
+                 "memory_available_mb": psutil.virtual_memory().available / (1024 * 1024)
+             })
+         except Exception as e:
+             logger.warning(f"Failed to get system metrics: {str(e)}")
+
+         return metrics
+
+     def reset_metrics(self):
+         """Reset all metrics."""
+         self.metrics = {
+             "requests_total": 0,
+             "requests_successful": 0,
+             "requests_failed": 0,
+             "total_processing_time": 0.0,
+             "average_processing_time": 0.0,
+             "min_processing_time": float('inf'),
+             "max_processing_time": 0.0
+         }
+
+
+ # Global performance monitor instance
+ performance_monitor = PerformanceMonitor()
+
+
+ @contextmanager
+ def measure_time():
+     """Context manager to measure execution time."""
+     start_time = time.time()
+     try:
+         yield
+     finally:
+         end_time = time.time()
+         execution_time = end_time - start_time
+         logger.debug(f"Execution time: {execution_time:.3f}s")
+
+
+ def timed_function(func: Callable) -> Callable:
+     """Decorator to measure function execution time."""
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.time()
+         success = True
+         try:
+             return func(*args, **kwargs)
+         except Exception:
+             success = False
+             raise
+         finally:
+             execution_time = time.time() - start_time
+             performance_monitor.record_request(execution_time, success)
+             logger.debug(f"{func.__name__} execution time: {execution_time:.3f}s")
+
+     return wrapper
+
+
+ def timed_async_function(func: Callable) -> Callable:
+     """Decorator to measure async function execution time.
+
+     Note: the decorator itself is a plain function; only the wrapper
+     it returns is async. (Declaring the decorator `async def` would make
+     `@timed_async_function` return a coroutine instead of a wrapper.)
+     """
+     @functools.wraps(func)
+     async def wrapper(*args, **kwargs):
+         start_time = time.time()
+         success = True
+         try:
+             return await func(*args, **kwargs)
+         except Exception:
+             success = False
+             raise
+         finally:
+             execution_time = time.time() - start_time
+             performance_monitor.record_request(execution_time, success)
+             logger.debug(f"{func.__name__} execution time: {execution_time:.3f}s")
+
+     return wrapper
+
+
+ class SimpleCache:
+     """Simple in-memory cache for inference results."""
+
+     def __init__(self, max_size: int = 100, ttl: int = 3600):
+         """
+         Initialize cache.
+
+         Args:
+             max_size: Maximum number of items to cache
+             ttl: Time to live in seconds
+         """
+         self.max_size = max_size
+         self.ttl = ttl
+         self.cache = {}
+         self.access_times = {}
+         self._hits = 0
+         self._requests = 0
+
+     def _is_expired(self, key: str) -> bool:
+         """Check if a cache entry is expired."""
+         if key not in self.access_times:
+             return True
+
+         return time.time() - self.access_times[key] > self.ttl
+
+     def _evict_expired(self):
+         """Remove expired entries."""
+         current_time = time.time()
+         expired_keys = [
+             key for key, access_time in self.access_times.items()
+             if current_time - access_time > self.ttl
+         ]
+
+         for key in expired_keys:
+             self.cache.pop(key, None)
+             self.access_times.pop(key, None)
+
+     def _evict_lru(self):
+         """Remove the least recently used entry."""
+         if not self.access_times:
+             return
+
+         lru_key = min(self.access_times.keys(), key=lambda k: self.access_times[k])
+         self.cache.pop(lru_key, None)
+         self.access_times.pop(lru_key, None)
+
+     def get(self, key: str) -> Optional[Any]:
+         """Get item from cache."""
+         self._requests += 1
+
+         if key not in self.cache or self._is_expired(key):
+             return None
+
+         self._hits += 1
+         self.access_times[key] = time.time()
+         return self.cache[key]
+
+     def set(self, key: str, value: Any):
+         """Set item in cache."""
+         # Clean up expired entries
+         self._evict_expired()
+
+         # Evict LRU if at max size
+         while len(self.cache) >= self.max_size:
+             self._evict_lru()
+
+         self.cache[key] = value
+         self.access_times[key] = time.time()
+
+     def clear(self):
+         """Clear all cache entries."""
+         self.cache.clear()
+         self.access_times.clear()
+
+     def size(self) -> int:
+         """Get current cache size."""
+         return len(self.cache)
+
+     def stats(self) -> Dict[str, Any]:
+         """Get cache statistics."""
+         return {
+             "size": len(self.cache),
+             "max_size": self.max_size,
+             "ttl": self.ttl,
+             "hit_ratio": self._hits / max(self._requests, 1)
+         }
+
+
+ # Global cache instance
+ inference_cache = SimpleCache(max_size=50, ttl=1800)  # 30 minutes TTL
+
+
+ def get_system_info() -> Dict[str, Any]:
+     """Get system information."""
+     try:
+         return {
+             "cpu_count": psutil.cpu_count(),
+             "cpu_percent": psutil.cpu_percent(interval=1),
+             "memory_total_gb": psutil.virtual_memory().total / (1024**3),
+             "memory_available_gb": psutil.virtual_memory().available / (1024**3),
+             "memory_percent": psutil.virtual_memory().percent,
+             "disk_usage_percent": psutil.disk_usage('/').percent
+         }
+     except Exception as e:
+         logger.error(f"Failed to get system info: {str(e)}")
+         return {"error": str(e)}
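The TTL-plus-LRU eviction that `SimpleCache` applies can be illustrated with a small standalone sketch. Note this is an illustrative re-implementation, not code from the repository: the `TinyTTLCache` name is hypothetical, and it uses a logical access counter for recency (rather than wall-clock timestamps) so the eviction order is deterministic:

```python
import time

class TinyTTLCache:
    """Minimal TTL + LRU cache mirroring SimpleCache's eviction rules."""
    def __init__(self, max_size=2, ttl=3600.0):
        self.max_size, self.ttl = max_size, ttl
        self.data = {}       # key -> value
        self.expiry = {}     # key -> absolute expiry time (for TTL)
        self.recency = {}    # key -> logical access counter (for LRU)
        self._tick = 0

    def _touch(self, key):
        self._tick += 1
        self.recency[key] = self._tick

    def set(self, key, value):
        # Evict the least-recently-used entry until there is room.
        while len(self.data) >= self.max_size:
            lru = min(self.recency, key=self.recency.get)
            for d in (self.data, self.expiry, self.recency):
                d.pop(lru, None)
        self.data[key] = value
        self.expiry[key] = time.time() + self.ttl
        self._touch(key)

    def get(self, key):
        # Missing or expired entries read as a miss (None).
        if key not in self.data or time.time() > self.expiry[key]:
            return None
        self._touch(key)  # a hit refreshes recency
        return self.data[key]

cache = TinyTTLCache(max_size=2)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")         # touching "a" makes "b" the LRU entry
cache.set("c", 3)      # at capacity: "b" is evicted
print(cache.get("b"))  # evicted -> None
print(cache.get("a"))  # still cached -> 1
```

The same separation of concerns (one structure for expiry, one for recency) also avoids the tie-breaking ambiguity that arises when a single timestamp dict serves both purposes on systems with coarse clock resolution.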
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ # FastAPI and web server
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+
+ # Machine Learning and Computer Vision
+ torch>=1.13.0
+ torchvision>=0.14.0
+ yolov5==7.0.13
+ opencv-python-headless==4.8.1.78
+ Pillow==10.1.0
+ numpy==1.24.3
+
+ # Data handling and validation
+ pydantic==2.5.0
+ pydantic-settings==2.1.0
+
+ # HuggingFace integration
+ huggingface-hub==0.19.4
+
+ # Utilities
+ python-multipart==0.0.6
+ aiofiles==23.2.1
+
+ # Performance monitoring and optimization
+ psutil==5.9.6
+
+ # Testing (optional, for development)
+ pytest==7.4.3
+ httpx==0.25.2
run_tests.py ADDED
@@ -0,0 +1,76 @@
+ #!/usr/bin/env python3
+ """
+ Test runner script for the Marine Species Identification API.
+ """
+
+ import subprocess
+ import sys
+ import os
+ from pathlib import Path
+
+ def run_pytest():
+     """Run pytest tests."""
+     print("🧪 Running pytest tests...")
+     try:
+         subprocess.run([
+             sys.executable, "-m", "pytest",
+             "tests/",
+             "-v",
+             "--tb=short"
+         ], check=True)
+         print("✅ All pytest tests passed!")
+         return True
+     except subprocess.CalledProcessError as e:
+         print(f"❌ Some pytest tests failed (exit code: {e.returncode})")
+         return False
+     except FileNotFoundError:
+         print("⚠️ pytest not found, skipping pytest tests")
+         return True
+
+ def run_simple_test():
+     """Run the simple API test (requires a running API instance)."""
+     print("🧪 Running simple API test...")
+     try:
+         subprocess.run([
+             sys.executable, "test_api_simple.py"
+         ], check=True)
+         print("✅ Simple API test completed!")
+         return True
+     except subprocess.CalledProcessError as e:
+         print(f"❌ Simple API test failed (exit code: {e.returncode})")
+         return False
+
+ def main():
+     """Main test runner."""
+     print("🐟 Marine Species Identification API - Test Runner")
+     print("=" * 55)
+
+     # Change to project directory
+     project_dir = Path(__file__).parent
+     os.chdir(project_dir)
+
+     success = True
+
+     # Run pytest tests
+     if not run_pytest():
+         success = False
+
+     print()
+
+     # Note about simple test
+     print("📝 Note: To run the simple API test, start the API first:")
+     print("   python start_api.py")
+     print("   # Then in another terminal:")
+     print("   python test_api_simple.py")
+
+     print("=" * 55)
+
+     if success:
+         print("🎉 All available tests completed successfully!")
+         return 0
+     else:
+         print("❌ Some tests failed")
+         return 1
+
+ if __name__ == "__main__":
+     sys.exit(main())
start_api.py ADDED
@@ -0,0 +1,131 @@
+ #!/usr/bin/env python3
+ """
+ Startup script for the Marine Species Identification API.
+ This script handles model downloading and API startup.
+ """
+
+ import asyncio
+ import sys
+ import os
+ from pathlib import Path
+
+ # Add the app directory to Python path
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from app.core.config import settings
+ from app.core.logging import setup_logging, get_logger
+ from app.utils.model_utils import (
+     download_model_from_hf,
+     verify_model_file,
+     setup_model_directory,
+     list_available_files
+ )
+
+ # Setup logging
+ setup_logging()
+ logger = get_logger(__name__)
+
+
+ async def ensure_model_available():
+     """Ensure the model is downloaded and available."""
+     logger.info("🔍 Checking model availability...")
+
+     # Setup model directory
+     model_dir = setup_model_directory()
+     logger.info(f"Model directory: {model_dir}")
+
+     # Check if model file exists
+     if verify_model_file(settings.MODEL_PATH):
+         logger.info("✅ Model file found and verified")
+         return True
+
+     logger.info("📥 Model not found locally, attempting to download...")
+
+     try:
+         # List available files in the repository
+         logger.info(f"Checking repository: {settings.HUGGINGFACE_REPO}")
+         available_files = list_available_files(settings.HUGGINGFACE_REPO)
+
+         if available_files:
+             logger.info("Available files in repository:")
+             for file in available_files[:10]:  # Show first 10 files
+                 logger.info(f"  - {file}")
+             if len(available_files) > 10:
+                 logger.info(f"  ... and {len(available_files) - 10} more files")
+
+         # Download the model
+         model_filename = f"{settings.MODEL_NAME}.pt"
+
+         if model_filename in available_files:
+             download_model_from_hf(
+                 repo_id=settings.HUGGINGFACE_REPO,
+                 model_filename=model_filename,
+                 local_dir=model_dir,
+                 force_download=False
+             )
+
+             # Verify the downloaded model
+             if verify_model_file(settings.MODEL_PATH):
+                 logger.info("✅ Model downloaded and verified successfully")
+                 return True
+             else:
+                 logger.error("❌ Downloaded model failed verification")
+                 return False
+         else:
+             logger.error(f"❌ Model file '{model_filename}' not found in repository")
+             logger.info("Available .pt files:")
+             pt_files = [f for f in available_files if f.endswith('.pt')]
+             for pt_file in pt_files:
+                 logger.info(f"  - {pt_file}")
+             return False
+
+     except Exception as e:
+         logger.error(f"❌ Failed to download model: {str(e)}")
+         return False
+
+
+ def start_api():
+     """Start the FastAPI application."""
+     import uvicorn
+
+     logger.info("🚀 Starting Marine Species Identification API...")
+     logger.info(f"Host: {settings.HOST}")
+     logger.info(f"Port: {settings.PORT}")
+     logger.info(f"Docs: http://{settings.HOST}:{settings.PORT}/docs")
+
+     uvicorn.run(
+         "app.main:app",
+         host=settings.HOST,
+         port=settings.PORT,
+         reload=False,
+         log_level="info",
+         access_log=True
+     )
+
+
+ async def main():
+     """Main startup function."""
+     logger.info("🐟 Marine Species Identification API Startup")
+     logger.info("=" * 50)
+
+     # Check model availability
+     model_available = await ensure_model_available()
+
+     if not model_available:
+         logger.warning("⚠️ Model not available - API will start but inference may fail")
+         logger.info("The API will still start and you can check /health for status")
+
+     logger.info("=" * 50)
+
+     # Start the API
+     start_api()
+
+
+ if __name__ == "__main__":
+     try:
+         asyncio.run(main())
+     except KeyboardInterrupt:
+         logger.info("🛑 API startup interrupted by user")
+     except Exception as e:
+         logger.error(f"❌ Failed to start API: {str(e)}")
+         sys.exit(1)
test_api_simple.py ADDED
@@ -0,0 +1,173 @@
+ #!/usr/bin/env python3
+ """
+ Simple test script for the Marine Species Identification API.
+ This script can be used to quickly test the API functionality.
+ """
+
+ import requests
+ import base64
+ import json
+ import time
+ from PIL import Image
+ import numpy as np
+ import io
+
+
+ def create_test_image(width: int = 640, height: int = 480) -> str:
+     """Create a test image and return it as a base64 string."""
+     # Create a simple test image with some patterns
+     image = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
+
+     # Add some simple patterns to make it more interesting
+     image[100:200, 100:200] = [255, 0, 0]  # Red square
+     image[300:400, 300:400] = [0, 255, 0]  # Green square
+
+     pil_image = Image.fromarray(image)
+
+     # Convert to base64
+     buffer = io.BytesIO()
+     pil_image.save(buffer, format="JPEG", quality=85)
+     image_bytes = buffer.getvalue()
+
+     return base64.b64encode(image_bytes).decode('utf-8')
+
+
+ def test_api(base_url: str = "http://localhost:7860"):
+     """Test the API endpoints."""
+
+     print(f"🧪 Testing Marine Species Identification API at {base_url}")
+     print("=" * 60)
+
+     # Test 1: Root endpoint
+     print("1. Testing root endpoint...")
+     try:
+         response = requests.get(f"{base_url}/")
+         print(f"   Status: {response.status_code}")
+         if response.status_code == 200:
+             print(f"   Response: {response.json()}")
+         print()
+     except Exception as e:
+         print(f"   Error: {e}")
+         return
+
+     # Test 2: Health check
+     print("2. Testing health check...")
+     try:
+         response = requests.get(f"{base_url}/api/v1/health")
+         print(f"   Status: {response.status_code}")
+         if response.status_code == 200:
+             health_data = response.json()
+             print(f"   API Status: {health_data.get('status')}")
+             print(f"   Model Loaded: {health_data.get('model_loaded')}")
+         print()
+     except Exception as e:
+         print(f"   Error: {e}")
+         print()
+
+     # Test 3: API info
+     print("3. Testing API info...")
+     try:
+         response = requests.get(f"{base_url}/api/v1/info")
+         print(f"   Status: {response.status_code}")
+         if response.status_code == 200:
+             info_data = response.json()
+             print(f"   API Name: {info_data.get('name')}")
+             print(f"   Version: {info_data.get('version')}")
+             model_info = info_data.get('model_info', {})
+             print(f"   Model Classes: {model_info.get('total_classes')}")
+         print()
+     except Exception as e:
+         print(f"   Error: {e}")
+         print()
+
+     # Test 4: Species list
+     print("4. Testing species list...")
+     try:
+         response = requests.get(f"{base_url}/api/v1/species")
+         print(f"   Status: {response.status_code}")
+         if response.status_code == 200:
+             species_data = response.json()
+             total_species = species_data.get('total_count', 0)
+             print(f"   Total Species: {total_species}")
+             if total_species > 0:
+                 print("   First 3 species:")
+                 for species in species_data.get('species', [])[:3]:
+                     print(f"     - {species.get('class_name')} (ID: {species.get('class_id')})")
+         print()
+     except Exception as e:
+         print(f"   Error: {e}")
+         print()
+
+     # Test 5: Detection with test image
+     print("5. Testing marine species detection...")
+     try:
+         # Create a test image
+         print("   Creating test image...")
+         test_image_b64 = create_test_image()
+
+         # Prepare request
+         detection_request = {
+             "image": test_image_b64,
+             "confidence_threshold": 0.25,
+             "iou_threshold": 0.45,
+             "image_size": 640,
+             "return_annotated_image": True
+         }
+
+         print("   Sending detection request...")
+         start_time = time.time()
+
+         response = requests.post(
+             f"{base_url}/api/v1/detect",
+             json=detection_request,
+             timeout=30
+         )
+
+         end_time = time.time()
+         request_time = end_time - start_time
+
+         print(f"   Status: {response.status_code}")
+         print(f"   Request Time: {request_time:.2f}s")
+
+         if response.status_code == 200:
+             detection_data = response.json()
+             detections = detection_data.get('detections', [])
+             processing_time = detection_data.get('processing_time', 0)
+
+             print(f"   Processing Time: {processing_time:.3f}s")
+             print(f"   Detections Found: {len(detections)}")
+
+             if detections:
+                 print("   Top detections:")
+                 for i, detection in enumerate(detections[:3]):
+                     print(f"     {i+1}. {detection.get('class_name')} "
+                           f"(confidence: {detection.get('confidence'):.3f})")
+
+             # Check if annotated image was returned
+             if detection_data.get('annotated_image'):
+                 print("   ✅ Annotated image returned")
+             else:
+                 print("   ❌ No annotated image returned")
+
+         elif response.status_code == 503:
+             print("   ⚠️ Service unavailable (model may not be loaded)")
+         else:
+             print(f"   ❌ Error: {response.text}")
+
+         print()
+
+     except Exception as e:
+         print(f"   Error: {e}")
+         print()
+
+     print("🎉 API testing completed!")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     import sys
+
+     # Allow custom base URL
+     base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
+
+     test_api(base_url)
tests/__init__.py ADDED
@@ -0,0 +1 @@
+ # Test package
tests/test_api.py ADDED
@@ -0,0 +1,175 @@
+ """
+ Basic API tests for the Marine Species Identification API.
+ """
+
+ import pytest
+ import base64
+ import io
+ from PIL import Image
+ import numpy as np
+ from fastapi.testclient import TestClient
+
+ from app.main import app
+
+ client = TestClient(app)
+
+
+ def create_test_image(width: int = 640, height: int = 480) -> str:
+     """Create a test image and return it as a base64 string."""
+     # Create a simple test image
+     image = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
+     pil_image = Image.fromarray(image)
+
+     # Convert to base64
+     buffer = io.BytesIO()
+     pil_image.save(buffer, format="JPEG")
+     image_bytes = buffer.getvalue()
+
+     return base64.b64encode(image_bytes).decode('utf-8')
+
+
+ class TestHealthEndpoints:
+     """Test health and status endpoints."""
+
+     def test_root_endpoint(self):
+         """Test root endpoint."""
+         response = client.get("/")
+         assert response.status_code == 200
+         data = response.json()
+         assert "message" in data
+         assert "version" in data
+
+     def test_root_health(self):
+         """Test root health endpoint."""
+         response = client.get("/health")
+         assert response.status_code == 200
+         data = response.json()
+         assert data["status"] == "ok"
+
+     def test_health_check(self):
+         """Test detailed health check."""
+         response = client.get("/api/v1/health")
+         assert response.status_code == 200
+         data = response.json()
+         assert "status" in data
+         assert "model_loaded" in data
+         assert "timestamp" in data
+
+     def test_api_info(self):
+         """Test API info endpoint."""
+         response = client.get("/api/v1/info")
+         assert response.status_code == 200
+         data = response.json()
+         assert "name" in data
+         assert "version" in data
+         assert "endpoints" in data
+
+     def test_liveness_check(self):
+         """Test liveness probe."""
+         response = client.get("/api/v1/live")
+         assert response.status_code == 200
+         data = response.json()
+         assert data["status"] == "alive"
+
+
+ class TestSpeciesEndpoints:
+     """Test species-related endpoints."""
+
+     def test_list_species(self):
+         """Test species list endpoint."""
+         response = client.get("/api/v1/species")
+         assert response.status_code in [200, 503]  # May fail if model not loaded
+
+         if response.status_code == 200:
+             data = response.json()
+             assert "species" in data
+             assert "total_count" in data
+             assert isinstance(data["species"], list)
+
+     def test_get_species_info(self):
+         """Test individual species info endpoint."""
+         # This may fail if the model is not loaded, which is expected in a test environment
+         response = client.get("/api/v1/species/0")
+         assert response.status_code in [200, 404, 503]
+
+
+ class TestInferenceEndpoints:
+     """Test inference endpoints."""
+
+     def test_detect_invalid_image(self):
+         """Test detection with invalid image data."""
+         response = client.post(
+             "/api/v1/detect",
+             json={
+                 "image": "invalid_base64_data",
+                 "confidence_threshold": 0.25
+             }
+         )
+         assert response.status_code in [400, 503]  # Bad request or service unavailable
+
+     def test_detect_valid_request_format(self):
+         """Test detection with valid request format."""
+         test_image = create_test_image()
+
+         response = client.post(
+             "/api/v1/detect",
+             json={
+                 "image": test_image,
+                 "confidence_threshold": 0.25,
+                 "iou_threshold": 0.45,
+                 "image_size": 640,
+                 "return_annotated_image": True
+             }
+         )
+
+         # May return 503 if the model is not loaded, which is expected in a test environment
+         assert response.status_code in [200, 503]
+
+         if response.status_code == 200:
+             data = response.json()
+             assert "detections" in data
+             assert "processing_time" in data
+             assert "model_info" in data
+             assert "image_dimensions" in data
+
+     def test_detect_parameter_validation(self):
+         """Test parameter validation."""
+         test_image = create_test_image()
+
+         # Test invalid confidence threshold
+         response = client.post(
+             "/api/v1/detect",
+             json={
+                 "image": test_image,
+                 "confidence_threshold": 1.5  # Invalid: > 1.0
+             }
+         )
+         assert response.status_code == 422  # Validation error
+
+         # Test invalid image size
+         response = client.post(
+             "/api/v1/detect",
+             json={
+                 "image": test_image,
+                 "image_size": 100  # Invalid: < 320
+             }
+         )
+         assert response.status_code == 422  # Validation error
+
+
+ class TestErrorHandling:
+     """Test error handling."""
+
+     def test_404_endpoint(self):
+         """Test non-existent endpoint."""
+         response = client.get("/api/v1/nonexistent")
+         assert response.status_code == 404
+
+     def test_method_not_allowed(self):
+         """Test wrong HTTP method."""
+         response = client.get("/api/v1/detect")  # Should be POST
+         assert response.status_code == 405
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__])