|
|
""" |
|
|
Model utilities for downloading and managing the marine species model. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, Any |
|
|
from huggingface_hub import hf_hub_download, list_repo_files |
|
|
|
|
|
from app.core.config import settings |
|
|
from app.core.logging import get_logger |
|
|
|
|
|
# Module-level logger from the project's logging helper (app.core.logging),
# shared by every function in this module.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
def download_model_from_hf(
    repo_id: str,
    model_filename: str,
    local_dir: str,
    force_download: bool = False
) -> str:
    """
    Download a model file from the HuggingFace Hub into ``local_dir``.

    If ``local_dir / model_filename`` already exists as a regular file and
    ``force_download`` is False, the Hub is not contacted and the existing
    path is returned.

    Args:
        repo_id: HuggingFace repository ID
        model_filename: Name of the model file inside the repository
        local_dir: Local directory to save the model (created if missing)
        force_download: Whether to force re-download if file exists

    Returns:
        Path to the downloaded (or already present) model file

    Raises:
        Any exception raised while creating ``local_dir`` or downloading is
        logged with its traceback and re-raised unchanged.
    """
    try:
        Path(local_dir).mkdir(parents=True, exist_ok=True)
        local_path = Path(local_dir) / model_filename

        # is_file() rather than exists(): a directory that happens to carry the
        # model's name must not be mistaken for an already-downloaded model.
        if local_path.is_file() and not force_download:
            logger.info(f"Model already exists at {local_path}")
            return str(local_path)

        logger.info(f"Downloading {model_filename} from {repo_id}...")
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=model_filename,
            local_dir=local_dir,
            # Older huggingface_hub releases would otherwise place a symlink
            # into the shared HF cache in local_dir; newer releases ignore this
            # deprecated argument and always write a real file.
            local_dir_use_symlinks=False,
            force_download=force_download
        )

        logger.info(f"Model downloaded successfully to: {downloaded_path}")
        return downloaded_path

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Failed to download model: {e}")
        raise
|
|
|
|
|
|
|
|
def list_available_files(repo_id: str) -> list:
    """
    Return the names of the files stored in a HuggingFace repository.

    This is best-effort: if querying the Hub fails for any reason, the error
    is logged and an empty list is returned instead of raising.

    Args:
        repo_id: HuggingFace repository ID

    Returns:
        The repository's file list, or ``[]`` when the lookup failed
    """
    try:
        return list_repo_files(repo_id)
    except Exception as err:
        logger.error(f"Failed to list repository files: {str(err)}")
        return []
|
|
|
|
|
|
|
|
def verify_model_file(model_path: str, *, min_size_bytes: int = 1024 * 1024) -> bool:
    """
    Verify that a model file exists and looks plausible.

    Checks, in order:
      * the path exists;
      * the path is a regular file (a directory is rejected explicitly);
      * the file is at least ``min_size_bytes`` long;
      * the extension is ``.pt`` or ``.pth``. An unexpected extension only
        logs a warning; it does not fail verification.

    Args:
        model_path: Path to the model file
        min_size_bytes: Minimum acceptable file size in bytes (default 1 MiB)

    Returns:
        True if the model file passes the checks. False otherwise, including
        when an error occurs while inspecting the file.
    """
    try:
        path = Path(model_path)

        if not path.exists():
            logger.error(f"Model file does not exist: {model_path}")
            return False

        # Without this check a directory fell through to the size check and
        # was reported as "too small", which is misleading.
        if not path.is_file():
            logger.error(f"Model path is not a regular file: {model_path}")
            return False

        # Cheap sanity check; presumably meant to catch truncated or
        # placeholder files.
        file_size = path.stat().st_size
        if file_size < min_size_bytes:
            logger.warning(f"Model file seems too small: {file_size} bytes")
            return False

        if path.suffix.lower() not in ('.pt', '.pth'):
            logger.warning(f"Unexpected model file extension: {path.suffix}")

        logger.info(f"Model file verified: {model_path} ({file_size / (1024*1024):.1f} MB)")
        return True

    except Exception as e:
        logger.error(f"Failed to verify model file: {str(e)}")
        return False
|
|
|
|
|
|
|
|
def get_model_info(model_path: str) -> Dict[str, Any]:
    """
    Get information about a model file.

    Args:
        model_path: Path to the model file

    Returns:
        Dictionary with the keys ``path``, ``exists``, ``size_mb`` and
        ``size_bytes``. When the path exists, ``modified_time`` (the file's
        ``st_mtime``) is also added. Errors while inspecting the path are
        logged, and whatever was collected so far is returned.
    """
    info: Dict[str, Any] = {
        "path": model_path,
        "exists": False,
        "size_mb": 0,
        "size_bytes": 0
    }

    try:
        path = Path(model_path)
        if path.exists():
            # A single stat() call: size and mtime come from the same snapshot,
            # and the filesystem is queried once instead of twice.
            stat_result = path.stat()
            info["exists"] = True
            info["size_bytes"] = stat_result.st_size
            info["size_mb"] = stat_result.st_size / (1024 * 1024)
            info["modified_time"] = stat_result.st_mtime
    except Exception as e:
        logger.error(f"Failed to get model info: {str(e)}")

    return info
|
|
|
|
|
|
|
|
def _default_hf_hub_cache_dir() -> Path:
    """Resolve the HuggingFace Hub download cache from HF_HUB_CACHE / HF_HOME / XDG_CACHE_HOME."""
    # Lookup order intended to mirror huggingface_hub's documented environment
    # variables: HF_HUB_CACHE, then the legacy HUGGINGFACE_HUB_CACHE, then
    # $HF_HOME/hub, then $XDG_CACHE_HOME/huggingface/hub, then ~/.cache/huggingface/hub.
    hub_cache = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if hub_cache:
        return Path(hub_cache).expanduser()

    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser() / "hub"

    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
    return base / "huggingface" / "hub"


def cleanup_model_cache(cache_dir: Optional[str] = None) -> None:
    """
    Clean up the model (HuggingFace Hub download) cache directory.

    The directory is removed recursively. When ``cache_dir`` is None, only the
    Hub download cache is targeted; see ``_default_hf_hub_cache_dir`` for how
    it is resolved. Previously the whole ``~/.cache/huggingface`` directory
    was deleted, which also removed unrelated files kept there (such as a
    saved access token) and ignored HF_HOME / HF_HUB_CACHE overrides.

    This is best-effort: errors are logged and not raised.

    Args:
        cache_dir: Cache directory to clean (uses the Hub download cache if None)
    """
    try:
        cache_path = Path(cache_dir) if cache_dir is not None else _default_hf_hub_cache_dir()

        if cache_path.exists():
            logger.info(f"Cleaning up cache directory: {cache_path}")
            shutil.rmtree(cache_path)
            logger.info("Cache cleanup completed")
        else:
            logger.info("Cache directory does not exist")

    except Exception as e:
        logger.error(f"Failed to cleanup cache: {str(e)}")
|
|
|
|
|
|
|
|
def setup_model_directory() -> str:
    """
    Make sure the directory that holds the model file exists.

    The directory is the parent of ``settings.MODEL_PATH``. It is created,
    together with any missing ancestors, if it is not already present.

    Returns:
        The model directory as a string path
    """
    target_dir = Path(settings.MODEL_PATH).parent
    target_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Model directory setup: {target_dir}")
    return f"{target_dir}"
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. Action flags can be combined; the selected
    # actions run in the order they are checked below.
    import argparse

    parser = argparse.ArgumentParser(description="Model management utility")
    parser.add_argument("--download", action="store_true", help="Download model from HuggingFace")
    parser.add_argument("--verify", action="store_true", help="Verify model file")
    parser.add_argument("--info", action="store_true", help="Show model information")
    parser.add_argument("--list-files", action="store_true", help="List available files in HF repo")
    parser.add_argument("--cleanup-cache", action="store_true", help="Cleanup model cache")
    parser.add_argument("--force", action="store_true", help="Force download even if file exists")

    args = parser.parse_args()

    # Without any action flag the script used to exit silently; show usage instead.
    if not any((args.download, args.verify, args.info, args.list_files, args.cleanup_cache)):
        parser.print_help()

    if args.download:
        # setup_model_directory() creates and returns the parent of MODEL_PATH,
        # which is exactly where the model should be downloaded.
        model_dir = setup_model_directory()
        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=f"{settings.MODEL_NAME}.pt",
            local_dir=model_dir,
            force_download=args.force
        )

    if args.verify:
        is_valid = verify_model_file(settings.MODEL_PATH)
        print(f"Model valid: {is_valid}")

    if args.info:
        info = get_model_info(settings.MODEL_PATH)
        print(f"Model info: {info}")

    if args.list_files:
        files = list_available_files(settings.HUGGINGFACE_REPO)
        print(f"Available files in {settings.HUGGINGFACE_REPO}:")
        for file in files:
            print(f" - {file}")

    if args.cleanup_cache:
        cleanup_model_cache()
|
|
|