kamau1 commited on
Commit
0c8b8e3
·
verified ·
1 Parent(s): b5859c0

Fix Pydantic field conflicts and model loading issues; add test and upload scripts for HF deployment

Browse files
app/models/inference.py CHANGED
@@ -25,6 +25,8 @@ class Detection(BaseModel):
25
 
26
  class ModelInfo(BaseModel):
27
  """Model information."""
 
 
28
  model_name: str = Field(..., description="Name of the model")
29
  total_classes: int = Field(..., description="Total number of species classes")
30
  device: str = Field(..., description="Device used for inference")
@@ -102,6 +104,8 @@ class SpeciesListResponse(BaseModel):
102
 
103
  class HealthResponse(BaseModel):
104
  """Response model for health check."""
 
 
105
  status: str = Field(..., description="API status")
106
  model_loaded: bool = Field(..., description="Whether the model is loaded")
107
  model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
@@ -117,6 +121,8 @@ class ErrorResponse(BaseModel):
117
 
118
  class APIInfo(BaseModel):
119
  """API information response."""
 
 
120
  name: str = Field(..., description="API name")
121
  version: str = Field(..., description="API version")
122
  description: str = Field(..., description="API description")
 
25
 
26
  class ModelInfo(BaseModel):
27
  """Model information."""
28
+ model_config = {"protected_namespaces": ()}
29
+
30
  model_name: str = Field(..., description="Name of the model")
31
  total_classes: int = Field(..., description="Total number of species classes")
32
  device: str = Field(..., description="Device used for inference")
 
104
 
105
  class HealthResponse(BaseModel):
106
  """Response model for health check."""
107
+ model_config = {"protected_namespaces": ()}
108
+
109
  status: str = Field(..., description="API status")
110
  model_loaded: bool = Field(..., description="Whether the model is loaded")
111
  model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
 
121
 
122
  class APIInfo(BaseModel):
123
  """API information response."""
124
+ model_config = {"protected_namespaces": ()}
125
+
126
  name: str = Field(..., description="API name")
127
  version: str = Field(..., description="API version")
128
  description: str = Field(..., description="API description")
app/services/model_service.py CHANGED
@@ -27,14 +27,26 @@ class ModelService:
27
  Downloads from HuggingFace Hub if not present locally.
28
  """
29
  model_path = Path(settings.MODEL_PATH)
30
-
31
  # Check if model exists locally
32
  if not model_path.exists():
33
  logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
34
- await self._download_model()
35
-
 
 
 
 
36
  # Load class names if available
37
  await self._load_class_names()
 
 
 
 
 
 
 
 
38
 
39
  async def _download_model(self) -> None:
40
  """Download model from HuggingFace Hub."""
@@ -122,20 +134,38 @@ class ModelService:
122
  def get_model_info(self) -> Dict:
123
  """
124
  Get comprehensive model information.
125
-
126
  Returns:
127
  Dictionary with model information
128
  """
129
- model = self.get_model()
130
- class_names = self.get_class_names()
131
-
132
- return {
133
- "model_name": settings.MODEL_NAME,
134
- "total_classes": len(class_names) if class_names else 0,
135
- "device": model.device,
136
- "model_path": settings.MODEL_PATH,
137
- "huggingface_repo": settings.HUGGINGFACE_REPO
138
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  async def health_check(self) -> Dict:
141
  """
 
27
  Downloads from HuggingFace Hub if not present locally.
28
  """
29
  model_path = Path(settings.MODEL_PATH)
30
+
31
  # Check if model exists locally
32
  if not model_path.exists():
33
  logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
34
+ try:
35
+ await self._download_model()
36
+ except Exception as e:
37
+ logger.error(f"Failed to download model: {e}")
38
+ # Continue anyway - the API can still run without the model
39
+
40
  # Load class names if available
41
  await self._load_class_names()
42
+
43
+ # Try to initialize the model to catch loading errors early
44
+ try:
45
+ self.get_model()
46
+ logger.info("Model loaded successfully during startup")
47
+ except Exception as e:
48
+ logger.error(f"Model failed to load during startup: {e}")
49
+ # Don't fail startup - let health checks handle this
50
 
51
  async def _download_model(self) -> None:
52
  """Download model from HuggingFace Hub."""
 
134
  def get_model_info(self) -> Dict:
135
  """
136
  Get comprehensive model information.
137
+
138
  Returns:
139
  Dictionary with model information
140
  """
141
+ try:
142
+ model = self.get_model()
143
+ class_names = self.get_class_names()
144
+
145
+ # Safely get device info
146
+ device_info = "unknown"
147
+ try:
148
+ device_info = str(model.device) if hasattr(model, 'device') else "unknown"
149
+ except Exception as e:
150
+ logger.warning(f"Could not get device info: {e}")
151
+
152
+ return {
153
+ "model_name": settings.MODEL_NAME,
154
+ "total_classes": len(class_names) if class_names else 0,
155
+ "device": device_info,
156
+ "model_path": settings.MODEL_PATH,
157
+ "huggingface_repo": settings.HUGGINGFACE_REPO
158
+ }
159
+ except Exception as e:
160
+ logger.error(f"Failed to get model info: {str(e)}")
161
+ # Return basic info even if model fails
162
+ return {
163
+ "model_name": settings.MODEL_NAME,
164
+ "total_classes": 0,
165
+ "device": "unknown",
166
+ "model_path": settings.MODEL_PATH,
167
+ "huggingface_repo": settings.HUGGINGFACE_REPO
168
+ }
169
 
170
  async def health_check(self) -> Dict:
171
  """
test_with_images.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test the Marine Species Identification API with real marine species images.
4
+ """
5
+
6
import base64
import json
import sys
import time
from pathlib import Path
from typing import Optional

import requests
12
+
13
+
14
def encode_image_to_base64(image_path: str) -> Optional[str]:
    """Encode an image file to a base64 string.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        The base64-encoded file contents as ASCII text, or None when the
        file cannot be read (the error is printed, not raised).
    """
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except OSError as e:
        # Narrowed from `except Exception`: only file-access failures are
        # expected here; anything else would indicate a bug worth surfacing.
        print(f"Error encoding image {image_path}: {e}")
        return None
22
+
23
+
24
def test_api_with_image(api_url: str, image_path: str, image_name: str) -> bool:
    """Test the detection endpoint with a specific image.

    Args:
        api_url: Base URL of the deployed API (no trailing slash).
        image_path: Filesystem path of the image to send.
        image_name: Display name used in console output.

    Returns:
        True if the request succeeded (HTTP 200), False otherwise.
    """
    print(f"\n🐟 Testing with {image_name}")
    print("-" * 50)

    # Encode image; bail out early if the file could not be read.
    print("πŸ“· Encoding image...")
    image_b64 = encode_image_to_base64(image_path)
    if not image_b64:
        print("❌ Failed to encode image")
        return False

    # Prepare request payload matching the API's detection schema.
    detection_request = {
        "image": image_b64,
        "confidence_threshold": 0.25,
        "iou_threshold": 0.45,
        "image_size": 640,
        "return_annotated_image": True
    }

    try:
        print("πŸ” Sending detection request...")
        start_time = time.time()

        response = requests.post(
            f"{api_url}/api/v1/detect",
            json=detection_request,
            timeout=60
        )

        request_time = time.time() - start_time
        print(f"⏱️ Request completed in {request_time:.2f}s")
        print(f"πŸ“Š Status Code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()

            detections = result.get('detections', [])
            processing_time = result.get('processing_time', 0)
            image_dims = result.get('image_dimensions', {})

            print(f"βœ… SUCCESS!")
            print(f"   Processing Time: {processing_time:.3f}s")
            print(f"   Image Dimensions: {image_dims.get('width')}x{image_dims.get('height')}")
            print(f"   Detections Found: {len(detections)}")

            if detections:
                print(f"   🎯 Top Detections:")
                for i, detection in enumerate(detections[:5]):  # Show top 5
                    species = detection.get('class_name', 'Unknown')
                    confidence = detection.get('confidence', 0)
                    bbox = detection.get('bbox', {})
                    print(f"      {i+1}. {species} (confidence: {confidence:.3f})")
                    print(f"         Box: x={bbox.get('x', 0):.0f}, y={bbox.get('y', 0):.0f}, "
                          f"w={bbox.get('width', 0):.0f}, h={bbox.get('height', 0):.0f}")
            else:
                print("   ℹ️ No marine species detected")

            # Check annotated image
            if result.get('annotated_image'):
                print("   πŸ–ΌοΈ Annotated image returned")

            return True

        elif response.status_code == 503:
            print("❌ Service Unavailable - Model may not be loaded")
            print(f"   Response: {response.text}")
            return False

        else:
            print(f"❌ Request failed with status {response.status_code}")
            try:
                error_detail = response.json()
                print(f"   Error: {error_detail}")
            # Was a bare `except:` — that also swallows KeyboardInterrupt
            # and SystemExit. response.json() raises a ValueError subclass
            # on a non-JSON body, so catch exactly that.
            except ValueError:
                print(f"   Response: {response.text}")
            return False

    except requests.exceptions.Timeout:
        print("❌ Request timed out")
        return False
    except Exception as e:
        print(f"❌ Request failed: {e}")
        return False
109
+
110
+
111
def main():
    """Run the image-based smoke test against a deployed API instance.

    Reads an optional API base URL from argv[1] (defaults to the public
    HF Space), checks the health endpoint, then posts every image found
    in docs/gradio/images to the detect endpoint and prints a summary.
    """
    # API URL - can be overridden with command line argument
    api_url = sys.argv[1] if len(sys.argv) > 1 else "https://seamo-ai-fishapi.hf.space"

    print("🐟 Marine Species Identification API - Image Testing")
    print("=" * 60)
    print(f"API URL: {api_url}")

    # Check API health first. A failure here is reported but does not
    # abort the run - the per-image requests will surface the real state.
    print("\nπŸ₯ Checking API health...")
    try:
        health_response = requests.get(f"{api_url}/api/v1/health", timeout=10)
        if health_response.status_code == 200:
            health_data = health_response.json()
            print(f"   Status: {health_data.get('status')}")
            print(f"   Model Loaded: {health_data.get('model_loaded')}")
        else:
            print(f"   Health check failed: {health_response.status_code}")
    except Exception as e:
        print(f"   Health check error: {e}")

    # Find test images (relative path: must run from the project root)
    image_dir = Path("docs/gradio/images")
    if not image_dir.exists():
        print(f"\n❌ Image directory not found: {image_dir}")
        print("Please make sure you're running this from the project root directory.")
        return

    # Get all image files
    image_files = []
    for ext in ['*.png', '*.jpg', '*.jpeg']:
        image_files.extend(image_dir.glob(ext))

    if not image_files:
        print(f"\n❌ No image files found in {image_dir}")
        return

    print(f"\nπŸ“ Found {len(image_files)} test images")

    # Test each image
    successful_tests = 0
    total_tests = len(image_files)

    for image_path in sorted(image_files):
        success = test_api_with_image(api_url, str(image_path), image_path.name)
        if success:
            successful_tests += 1

        # Small delay between requests to avoid hammering the Space
        time.sleep(1)

    # Summary. The division is safe: the empty-list case returned early above.
    print("\n" + "=" * 60)
    print(f"🎯 Test Summary:")
    print(f"   Total Tests: {total_tests}")
    print(f"   Successful: {successful_tests}")
    print(f"   Failed: {total_tests - successful_tests}")
    print(f"   Success Rate: {(successful_tests/total_tests)*100:.1f}%")

    if successful_tests == total_tests:
        print("πŸŽ‰ All tests passed!")
    elif successful_tests > 0:
        print("⚠️ Some tests passed, some failed")
    else:
        print("❌ All tests failed")


if __name__ == "__main__":
    main()
upload_model_to_hf.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to upload the marine species model files to HuggingFace Hub.
4
+ This script helps you upload your model files to make them available for the API.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ from pathlib import Path
10
+ from huggingface_hub import HfApi, create_repo
11
+ import argparse
12
+
13
+
14
def upload_model_files(repo_id: str, local_model_dir: str, token: str = None) -> bool:
    """
    Upload model files to HuggingFace Hub.

    Scans `local_model_dir` (non-recursively) for weight/label files,
    creates the repository if needed, and uploads each file to the
    repo root under its original filename.

    Args:
        repo_id: HuggingFace repository ID (e.g., "username/repo-name")
        local_model_dir: Local directory containing model files
        token: HuggingFace token (optional, will use saved token if not provided)

    Returns:
        True when every file uploaded successfully, False on any failure.
    """

    print(f"πŸš€ Uploading model files to HuggingFace Hub")
    print(f"Repository: {repo_id}")
    print(f"Local directory: {local_model_dir}")
    print("=" * 60)

    # Initialize HF API (falls back to the locally saved token when None)
    api = HfApi(token=token)

    # Check if local directory exists
    local_path = Path(local_model_dir)
    if not local_path.exists():
        print(f"❌ Local directory not found: {local_path}")
        return False

    # Find model files: weights (*.pt/*.pth) plus class-name files
    model_files = []
    for pattern in ["*.pt", "*.pth", "*.names", "*.txt"]:
        model_files.extend(local_path.glob(pattern))

    if not model_files:
        print(f"❌ No model files found in {local_path}")
        print("Looking for: *.pt, *.pth, *.names, *.txt files")
        return False

    print(f"πŸ“ Found {len(model_files)} files to upload:")
    for file in model_files:
        size_mb = file.stat().st_size / (1024 * 1024)
        print(f"   - {file.name} ({size_mb:.1f} MB)")

    try:
        # Create repository if it doesn't exist. A failure here is only a
        # warning: with exist_ok=True an existing repo is fine, and the
        # upload below will fail loudly if the repo truly isn't usable.
        print(f"\nπŸ—οΈ Creating/checking repository...")
        try:
            create_repo(repo_id, exist_ok=True, token=token)
            print(f"βœ… Repository ready: {repo_id}")
        except Exception as e:
            print(f"⚠️ Repository creation warning: {e}")

        # Upload each file to the repo root, keeping the local filename
        print(f"\nπŸ“€ Uploading files...")
        for file_path in model_files:
            print(f"   Uploading {file_path.name}...")

            api.upload_file(
                path_or_fileobj=str(file_path),
                path_in_repo=file_path.name,
                repo_id=repo_id,
                token=token
            )
            print(f"   βœ… {file_path.name} uploaded successfully")

        print(f"\nπŸŽ‰ All files uploaded successfully!")
        print(f"πŸ”— Repository URL: https://huggingface.co/{repo_id}")
        return True

    except Exception as e:
        # Any upload error aborts the whole batch; earlier files stay uploaded.
        print(f"❌ Upload failed: {e}")
        return False
82
+
83
+
84
def main():
    """CLI entry point: parse arguments and upload model files to HF Hub.

    Exits with status 0 on success, 1 on failure, so the script can be
    used in deployment pipelines.
    """
    parser = argparse.ArgumentParser(description="Upload marine species model to HuggingFace Hub")
    parser.add_argument("--repo-id", required=True, help="HuggingFace repository ID (e.g., username/repo-name)")
    parser.add_argument("--model-dir", default="docs/gradio", help="Local directory containing model files")
    parser.add_argument("--token", help="HuggingFace token (optional)")

    args = parser.parse_args()

    print("🐟 Marine Species Model Upload Tool")
    print("=" * 50)

    # Check if user is logged in to HuggingFace. Without --token the hub
    # library uses the token saved by `huggingface-cli login`, so this is
    # only a reminder, not a hard failure.
    if not args.token:
        print("πŸ’‘ Tip: Make sure you're logged in to HuggingFace CLI:")
        print("   huggingface-cli login")
        print()

    # Upload files
    success = upload_model_files(
        repo_id=args.repo_id,
        local_model_dir=args.model_dir,
        token=args.token
    )

    if success:
        print(f"\nβœ… Upload completed successfully!")
        print(f"πŸ”§ Next steps:")
        print(f"   1. Update your API configuration to use: {args.repo_id}")
        print(f"   2. Redeploy your API")
        print(f"   3. Test the API with: python test_with_images.py")
        sys.exit(0)
    else:
        print(f"\n❌ Upload failed!")
        sys.exit(1)


if __name__ == "__main__":
    main()
+ main()