Fix Pydantic field conflicts and model loading issues; add test and upload scripts for HF deployment
Browse files- app/models/inference.py +6 -0
- app/services/model_service.py +44 -14
- test_with_images.py +180 -0
- upload_model_to_hf.py +122 -0
app/models/inference.py
CHANGED
|
@@ -25,6 +25,8 @@ class Detection(BaseModel):
|
|
| 25 |
|
| 26 |
class ModelInfo(BaseModel):
|
| 27 |
"""Model information."""
|
|
|
|
|
|
|
| 28 |
model_name: str = Field(..., description="Name of the model")
|
| 29 |
total_classes: int = Field(..., description="Total number of species classes")
|
| 30 |
device: str = Field(..., description="Device used for inference")
|
|
@@ -102,6 +104,8 @@ class SpeciesListResponse(BaseModel):
|
|
| 102 |
|
| 103 |
class HealthResponse(BaseModel):
|
| 104 |
"""Response model for health check."""
|
|
|
|
|
|
|
| 105 |
status: str = Field(..., description="API status")
|
| 106 |
model_loaded: bool = Field(..., description="Whether the model is loaded")
|
| 107 |
model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
|
|
@@ -117,6 +121,8 @@ class ErrorResponse(BaseModel):
|
|
| 117 |
|
| 118 |
class APIInfo(BaseModel):
|
| 119 |
"""API information response."""
|
|
|
|
|
|
|
| 120 |
name: str = Field(..., description="API name")
|
| 121 |
version: str = Field(..., description="API version")
|
| 122 |
description: str = Field(..., description="API description")
|
|
|
|
| 25 |
|
| 26 |
class ModelInfo(BaseModel):
|
| 27 |
"""Model information."""
|
| 28 |
+
model_config = {"protected_namespaces": ()}
|
| 29 |
+
|
| 30 |
model_name: str = Field(..., description="Name of the model")
|
| 31 |
total_classes: int = Field(..., description="Total number of species classes")
|
| 32 |
device: str = Field(..., description="Device used for inference")
|
|
|
|
| 104 |
|
| 105 |
class HealthResponse(BaseModel):
|
| 106 |
"""Response model for health check."""
|
| 107 |
+
model_config = {"protected_namespaces": ()}
|
| 108 |
+
|
| 109 |
status: str = Field(..., description="API status")
|
| 110 |
model_loaded: bool = Field(..., description="Whether the model is loaded")
|
| 111 |
model_info: Optional[ModelInfo] = Field(default=None, description="Model information")
|
|
|
|
| 121 |
|
| 122 |
class APIInfo(BaseModel):
|
| 123 |
"""API information response."""
|
| 124 |
+
model_config = {"protected_namespaces": ()}
|
| 125 |
+
|
| 126 |
name: str = Field(..., description="API name")
|
| 127 |
version: str = Field(..., description="API version")
|
| 128 |
description: str = Field(..., description="API description")
|
app/services/model_service.py
CHANGED
|
@@ -27,14 +27,26 @@ class ModelService:
|
|
| 27 |
Downloads from HuggingFace Hub if not present locally.
|
| 28 |
"""
|
| 29 |
model_path = Path(settings.MODEL_PATH)
|
| 30 |
-
|
| 31 |
# Check if model exists locally
|
| 32 |
if not model_path.exists():
|
| 33 |
logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Load class names if available
|
| 37 |
await self._load_class_names()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
async def _download_model(self) -> None:
|
| 40 |
"""Download model from HuggingFace Hub."""
|
|
@@ -122,20 +134,38 @@ class ModelService:
|
|
| 122 |
def get_model_info(self) -> Dict:
|
| 123 |
"""
|
| 124 |
Get comprehensive model information.
|
| 125 |
-
|
| 126 |
Returns:
|
| 127 |
Dictionary with model information
|
| 128 |
"""
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
async def health_check(self) -> Dict:
|
| 141 |
"""
|
|
|
|
| 27 |
Downloads from HuggingFace Hub if not present locally.
|
| 28 |
"""
|
| 29 |
model_path = Path(settings.MODEL_PATH)
|
| 30 |
+
|
| 31 |
# Check if model exists locally
|
| 32 |
if not model_path.exists():
|
| 33 |
logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
|
| 34 |
+
try:
|
| 35 |
+
await self._download_model()
|
| 36 |
+
except Exception as e:
|
| 37 |
+
logger.error(f"Failed to download model: {e}")
|
| 38 |
+
# Continue anyway - the API can still run without the model
|
| 39 |
+
|
| 40 |
# Load class names if available
|
| 41 |
await self._load_class_names()
|
| 42 |
+
|
| 43 |
+
# Try to initialize the model to catch loading errors early
|
| 44 |
+
try:
|
| 45 |
+
self.get_model()
|
| 46 |
+
logger.info("Model loaded successfully during startup")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
logger.error(f"Model failed to load during startup: {e}")
|
| 49 |
+
# Don't fail startup - let health checks handle this
|
| 50 |
|
| 51 |
async def _download_model(self) -> None:
|
| 52 |
"""Download model from HuggingFace Hub."""
|
|
|
|
| 134 |
def get_model_info(self) -> Dict:
|
| 135 |
"""
|
| 136 |
Get comprehensive model information.
|
| 137 |
+
|
| 138 |
Returns:
|
| 139 |
Dictionary with model information
|
| 140 |
"""
|
| 141 |
+
try:
|
| 142 |
+
model = self.get_model()
|
| 143 |
+
class_names = self.get_class_names()
|
| 144 |
+
|
| 145 |
+
# Safely get device info
|
| 146 |
+
device_info = "unknown"
|
| 147 |
+
try:
|
| 148 |
+
device_info = str(model.device) if hasattr(model, 'device') else "unknown"
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.warning(f"Could not get device info: {e}")
|
| 151 |
+
|
| 152 |
+
return {
|
| 153 |
+
"model_name": settings.MODEL_NAME,
|
| 154 |
+
"total_classes": len(class_names) if class_names else 0,
|
| 155 |
+
"device": device_info,
|
| 156 |
+
"model_path": settings.MODEL_PATH,
|
| 157 |
+
"huggingface_repo": settings.HUGGINGFACE_REPO
|
| 158 |
+
}
|
| 159 |
+
except Exception as e:
|
| 160 |
+
logger.error(f"Failed to get model info: {str(e)}")
|
| 161 |
+
# Return basic info even if model fails
|
| 162 |
+
return {
|
| 163 |
+
"model_name": settings.MODEL_NAME,
|
| 164 |
+
"total_classes": 0,
|
| 165 |
+
"device": "unknown",
|
| 166 |
+
"model_path": settings.MODEL_PATH,
|
| 167 |
+
"huggingface_repo": settings.HUGGINGFACE_REPO
|
| 168 |
+
}
|
| 169 |
|
| 170 |
async def health_check(self) -> Dict:
|
| 171 |
"""
|
test_with_images.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test the Marine Species Identification API with real marine species images.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import base64
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def encode_image_to_base64(image_path: str) -> str:
|
| 15 |
+
"""Encode an image file to base64 string."""
|
| 16 |
+
try:
|
| 17 |
+
with open(image_path, "rb") as image_file:
|
| 18 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"Error encoding image {image_path}: {e}")
|
| 21 |
+
return None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_api_with_image(api_url: str, image_path: str, image_name: str):
|
| 25 |
+
"""Test the API with a specific image."""
|
| 26 |
+
print(f"\nπ Testing with {image_name}")
|
| 27 |
+
print("-" * 50)
|
| 28 |
+
|
| 29 |
+
# Encode image
|
| 30 |
+
print("π· Encoding image...")
|
| 31 |
+
image_b64 = encode_image_to_base64(image_path)
|
| 32 |
+
if not image_b64:
|
| 33 |
+
print("β Failed to encode image")
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
# Prepare request
|
| 37 |
+
detection_request = {
|
| 38 |
+
"image": image_b64,
|
| 39 |
+
"confidence_threshold": 0.25,
|
| 40 |
+
"iou_threshold": 0.45,
|
| 41 |
+
"image_size": 640,
|
| 42 |
+
"return_annotated_image": True
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
print("π Sending detection request...")
|
| 47 |
+
start_time = time.time()
|
| 48 |
+
|
| 49 |
+
response = requests.post(
|
| 50 |
+
f"{api_url}/api/v1/detect",
|
| 51 |
+
json=detection_request,
|
| 52 |
+
timeout=60
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
request_time = time.time() - start_time
|
| 56 |
+
print(f"β±οΈ Request completed in {request_time:.2f}s")
|
| 57 |
+
print(f"π Status Code: {response.status_code}")
|
| 58 |
+
|
| 59 |
+
if response.status_code == 200:
|
| 60 |
+
result = response.json()
|
| 61 |
+
|
| 62 |
+
detections = result.get('detections', [])
|
| 63 |
+
processing_time = result.get('processing_time', 0)
|
| 64 |
+
image_dims = result.get('image_dimensions', {})
|
| 65 |
+
|
| 66 |
+
print(f"β
SUCCESS!")
|
| 67 |
+
print(f" Processing Time: {processing_time:.3f}s")
|
| 68 |
+
print(f" Image Dimensions: {image_dims.get('width')}x{image_dims.get('height')}")
|
| 69 |
+
print(f" Detections Found: {len(detections)}")
|
| 70 |
+
|
| 71 |
+
if detections:
|
| 72 |
+
print(f" π― Top Detections:")
|
| 73 |
+
for i, detection in enumerate(detections[:5]): # Show top 5
|
| 74 |
+
species = detection.get('class_name', 'Unknown')
|
| 75 |
+
confidence = detection.get('confidence', 0)
|
| 76 |
+
bbox = detection.get('bbox', {})
|
| 77 |
+
print(f" {i+1}. {species} (confidence: {confidence:.3f})")
|
| 78 |
+
print(f" Box: x={bbox.get('x', 0):.0f}, y={bbox.get('y', 0):.0f}, "
|
| 79 |
+
f"w={bbox.get('width', 0):.0f}, h={bbox.get('height', 0):.0f}")
|
| 80 |
+
else:
|
| 81 |
+
print(" βΉοΈ No marine species detected")
|
| 82 |
+
|
| 83 |
+
# Check annotated image
|
| 84 |
+
if result.get('annotated_image'):
|
| 85 |
+
print(" πΌοΈ Annotated image returned")
|
| 86 |
+
|
| 87 |
+
return True
|
| 88 |
+
|
| 89 |
+
elif response.status_code == 503:
|
| 90 |
+
print("β Service Unavailable - Model may not be loaded")
|
| 91 |
+
print(f" Response: {response.text}")
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
print(f"β Request failed with status {response.status_code}")
|
| 96 |
+
try:
|
| 97 |
+
error_detail = response.json()
|
| 98 |
+
print(f" Error: {error_detail}")
|
| 99 |
+
except:
|
| 100 |
+
print(f" Response: {response.text}")
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
except requests.exceptions.Timeout:
|
| 104 |
+
print("β Request timed out")
|
| 105 |
+
return False
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"β Request failed: {e}")
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def main():
|
| 112 |
+
"""Main test function."""
|
| 113 |
+
# API URL - can be overridden with command line argument
|
| 114 |
+
api_url = sys.argv[1] if len(sys.argv) > 1 else "https://seamo-ai-fishapi.hf.space"
|
| 115 |
+
|
| 116 |
+
print("π Marine Species Identification API - Image Testing")
|
| 117 |
+
print("=" * 60)
|
| 118 |
+
print(f"API URL: {api_url}")
|
| 119 |
+
|
| 120 |
+
# Check API health first
|
| 121 |
+
print("\nπ₯ Checking API health...")
|
| 122 |
+
try:
|
| 123 |
+
health_response = requests.get(f"{api_url}/api/v1/health", timeout=10)
|
| 124 |
+
if health_response.status_code == 200:
|
| 125 |
+
health_data = health_response.json()
|
| 126 |
+
print(f" Status: {health_data.get('status')}")
|
| 127 |
+
print(f" Model Loaded: {health_data.get('model_loaded')}")
|
| 128 |
+
else:
|
| 129 |
+
print(f" Health check failed: {health_response.status_code}")
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f" Health check error: {e}")
|
| 132 |
+
|
| 133 |
+
# Find test images
|
| 134 |
+
image_dir = Path("docs/gradio/images")
|
| 135 |
+
if not image_dir.exists():
|
| 136 |
+
print(f"\nβ Image directory not found: {image_dir}")
|
| 137 |
+
print("Please make sure you're running this from the project root directory.")
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
# Get all image files
|
| 141 |
+
image_files = []
|
| 142 |
+
for ext in ['*.png', '*.jpg', '*.jpeg']:
|
| 143 |
+
image_files.extend(image_dir.glob(ext))
|
| 144 |
+
|
| 145 |
+
if not image_files:
|
| 146 |
+
print(f"\nβ No image files found in {image_dir}")
|
| 147 |
+
return
|
| 148 |
+
|
| 149 |
+
print(f"\nπ Found {len(image_files)} test images")
|
| 150 |
+
|
| 151 |
+
# Test each image
|
| 152 |
+
successful_tests = 0
|
| 153 |
+
total_tests = len(image_files)
|
| 154 |
+
|
| 155 |
+
for image_path in sorted(image_files):
|
| 156 |
+
success = test_api_with_image(api_url, str(image_path), image_path.name)
|
| 157 |
+
if success:
|
| 158 |
+
successful_tests += 1
|
| 159 |
+
|
| 160 |
+
# Small delay between requests
|
| 161 |
+
time.sleep(1)
|
| 162 |
+
|
| 163 |
+
# Summary
|
| 164 |
+
print("\n" + "=" * 60)
|
| 165 |
+
print(f"π― Test Summary:")
|
| 166 |
+
print(f" Total Tests: {total_tests}")
|
| 167 |
+
print(f" Successful: {successful_tests}")
|
| 168 |
+
print(f" Failed: {total_tests - successful_tests}")
|
| 169 |
+
print(f" Success Rate: {(successful_tests/total_tests)*100:.1f}%")
|
| 170 |
+
|
| 171 |
+
if successful_tests == total_tests:
|
| 172 |
+
print("π All tests passed!")
|
| 173 |
+
elif successful_tests > 0:
|
| 174 |
+
print("β οΈ Some tests passed, some failed")
|
| 175 |
+
else:
|
| 176 |
+
print("β All tests failed")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
main()
|
upload_model_to_hf.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to upload the marine species model files to HuggingFace Hub.
|
| 4 |
+
This script helps you upload your model files to make them available for the API.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from huggingface_hub import HfApi, create_repo
|
| 11 |
+
import argparse
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def upload_model_files(repo_id: str, local_model_dir: str, token: str = None):
|
| 15 |
+
"""
|
| 16 |
+
Upload model files to HuggingFace Hub.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
repo_id: HuggingFace repository ID (e.g., "username/repo-name")
|
| 20 |
+
local_model_dir: Local directory containing model files
|
| 21 |
+
token: HuggingFace token (optional, will use saved token if not provided)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
print(f"π Uploading model files to HuggingFace Hub")
|
| 25 |
+
print(f"Repository: {repo_id}")
|
| 26 |
+
print(f"Local directory: {local_model_dir}")
|
| 27 |
+
print("=" * 60)
|
| 28 |
+
|
| 29 |
+
# Initialize HF API
|
| 30 |
+
api = HfApi(token=token)
|
| 31 |
+
|
| 32 |
+
# Check if local directory exists
|
| 33 |
+
local_path = Path(local_model_dir)
|
| 34 |
+
if not local_path.exists():
|
| 35 |
+
print(f"β Local directory not found: {local_path}")
|
| 36 |
+
return False
|
| 37 |
+
|
| 38 |
+
# Find model files
|
| 39 |
+
model_files = []
|
| 40 |
+
for pattern in ["*.pt", "*.pth", "*.names", "*.txt"]:
|
| 41 |
+
model_files.extend(local_path.glob(pattern))
|
| 42 |
+
|
| 43 |
+
if not model_files:
|
| 44 |
+
print(f"β No model files found in {local_path}")
|
| 45 |
+
print("Looking for: *.pt, *.pth, *.names, *.txt files")
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
print(f"π Found {len(model_files)} files to upload:")
|
| 49 |
+
for file in model_files:
|
| 50 |
+
size_mb = file.stat().st_size / (1024 * 1024)
|
| 51 |
+
print(f" - {file.name} ({size_mb:.1f} MB)")
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
# Create repository if it doesn't exist
|
| 55 |
+
print(f"\nποΈ Creating/checking repository...")
|
| 56 |
+
try:
|
| 57 |
+
create_repo(repo_id, exist_ok=True, token=token)
|
| 58 |
+
print(f"β
Repository ready: {repo_id}")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
print(f"β οΈ Repository creation warning: {e}")
|
| 61 |
+
|
| 62 |
+
# Upload each file
|
| 63 |
+
print(f"\nπ€ Uploading files...")
|
| 64 |
+
for file_path in model_files:
|
| 65 |
+
print(f" Uploading {file_path.name}...")
|
| 66 |
+
|
| 67 |
+
api.upload_file(
|
| 68 |
+
path_or_fileobj=str(file_path),
|
| 69 |
+
path_in_repo=file_path.name,
|
| 70 |
+
repo_id=repo_id,
|
| 71 |
+
token=token
|
| 72 |
+
)
|
| 73 |
+
print(f" β
{file_path.name} uploaded successfully")
|
| 74 |
+
|
| 75 |
+
print(f"\nπ All files uploaded successfully!")
|
| 76 |
+
print(f"π Repository URL: https://huggingface.co/{repo_id}")
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"β Upload failed: {e}")
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main():
|
| 85 |
+
"""Main function."""
|
| 86 |
+
parser = argparse.ArgumentParser(description="Upload marine species model to HuggingFace Hub")
|
| 87 |
+
parser.add_argument("--repo-id", required=True, help="HuggingFace repository ID (e.g., username/repo-name)")
|
| 88 |
+
parser.add_argument("--model-dir", default="docs/gradio", help="Local directory containing model files")
|
| 89 |
+
parser.add_argument("--token", help="HuggingFace token (optional)")
|
| 90 |
+
|
| 91 |
+
args = parser.parse_args()
|
| 92 |
+
|
| 93 |
+
print("π Marine Species Model Upload Tool")
|
| 94 |
+
print("=" * 50)
|
| 95 |
+
|
| 96 |
+
# Check if user is logged in to HuggingFace
|
| 97 |
+
if not args.token:
|
| 98 |
+
print("π‘ Tip: Make sure you're logged in to HuggingFace CLI:")
|
| 99 |
+
print(" huggingface-cli login")
|
| 100 |
+
print()
|
| 101 |
+
|
| 102 |
+
# Upload files
|
| 103 |
+
success = upload_model_files(
|
| 104 |
+
repo_id=args.repo_id,
|
| 105 |
+
local_model_dir=args.model_dir,
|
| 106 |
+
token=args.token
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if success:
|
| 110 |
+
print(f"\nβ
Upload completed successfully!")
|
| 111 |
+
print(f"π§ Next steps:")
|
| 112 |
+
print(f" 1. Update your API configuration to use: {args.repo_id}")
|
| 113 |
+
print(f" 2. Redeploy your API")
|
| 114 |
+
print(f" 3. Test the API with: python test_with_images.py")
|
| 115 |
+
sys.exit(0)
|
| 116 |
+
else:
|
| 117 |
+
print(f"\nβ Upload failed!")
|
| 118 |
+
sys.exit(1)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
main()
|