fishapi / app /utils /model_utils.py
kamau1's picture
Initial commit
bcc2f7b verified
"""
Model utilities for downloading and managing the marine species model.
"""
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Any
from huggingface_hub import hf_hub_download, list_repo_files
from app.core.config import settings
from app.core.logging import get_logger
# Module-level logger, named after this module; used by the helpers below.
logger = get_logger(__name__)
def download_model_from_hf(
    repo_id: str,
    model_filename: str,
    local_dir: str,
    force_download: bool = False
) -> str:
    """
    Fetch a model file from the HuggingFace Hub into a local directory.

    If the target file is already present and ``force_download`` is False,
    no download happens and the existing local path is returned.

    Args:
        repo_id: HuggingFace repository ID
        model_filename: Name of the model file inside the repository
        local_dir: Local directory the file is written to (created if missing)
        force_download: Re-download even when the file already exists locally

    Returns:
        Path to the model file: the existing local path, or the path reported
        by ``hf_hub_download``

    Raises:
        Any exception raised while creating the directory or downloading is
        logged and then re-raised unchanged.
    """
    try:
        target_dir = Path(local_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        target_file = target_dir / model_filename

        # Reuse a previously downloaded copy unless the caller asked for a fresh one.
        already_present = target_file.exists()
        if already_present and not force_download:
            logger.info(f"Model already exists at {target_file}")
            return str(target_file)

        logger.info(f"Downloading {model_filename} from {repo_id}...")
        result_path = hf_hub_download(
            repo_id=repo_id,
            filename=model_filename,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            force_download=force_download,
        )
        logger.info(f"Model downloaded successfully to: {result_path}")
        return result_path
    except Exception as exc:
        logger.error(f"Failed to download model: {str(exc)}")
        raise
def list_available_files(repo_id: str) -> list:
    """
    Return the file names stored in a HuggingFace repository.

    This is best-effort: any error from the Hub is logged and an empty list is
    returned, so callers cannot tell "empty repository" apart from "lookup failed".

    Args:
        repo_id: HuggingFace repository ID

    Returns:
        The repository's file list, or ``[]`` on failure
    """
    try:
        return list_repo_files(repo_id)
    except Exception as exc:
        logger.error(f"Failed to list repository files: {str(exc)}")
        return []
def verify_model_file(
    model_path: str,
    min_size_bytes: int = 1024 * 1024,
    expected_suffixes: tuple = (".pt", ".pth"),
) -> bool:
    """
    Verify that a model file exists and looks plausible.

    Checks, in order:
      1. the path exists,
      2. it is a regular file (not a directory),
      3. it is at least ``min_size_bytes`` bytes long.
    An extension outside ``expected_suffixes`` only produces a warning; it does
    not fail verification.

    Args:
        model_path: Path to the model file
        min_size_bytes: Smallest size accepted as a real model (default 1 MiB)
        expected_suffixes: Lower-case file extensions considered normal

    Returns:
        True if the model file passed the checks, False otherwise (including
        when an unexpected error occurs; the error is logged)
    """
    try:
        path = Path(model_path)

        if not path.exists():
            logger.error(f"Model file does not exist: {model_path}")
            return False

        # A directory's st_size is filesystem-dependent and says nothing about a
        # model; without this check it would be judged by that size instead.
        if not path.is_file():
            logger.error(f"Model path is not a regular file: {model_path}")
            return False

        file_size = path.stat().st_size
        if file_size < min_size_bytes:
            logger.warning(f"Model file seems too small: {file_size} bytes")
            return False

        # Unexpected extensions are reported but tolerated.
        if path.suffix.lower() not in expected_suffixes:
            logger.warning(f"Unexpected model file extension: {path.suffix}")

        logger.info(f"Model file verified: {model_path} ({file_size / (1024 * 1024):.1f} MB)")
        return True
    except Exception as e:
        logger.error(f"Failed to verify model file: {str(e)}")
        return False
def get_model_info(model_path: str) -> Dict[str, Any]:
    """
    Describe a model file on disk.

    Args:
        model_path: Path to the model file

    Returns:
        Dictionary with keys:
          - ``path``: the argument as given
          - ``exists``: whether the path exists
          - ``size_bytes`` / ``size_mb``: file size (0 when the path doesn't exist)
          - ``modified_time``: ``st_mtime`` (seconds since the epoch), present
            only when the path exists
        Errors while inspecting the path are logged and the defaults are returned.
    """
    info: Dict[str, Any] = {
        "path": model_path,
        "exists": False,
        "size_mb": 0,
        "size_bytes": 0,
    }
    try:
        path = Path(model_path)
        if path.exists():
            # A single stat() so size and mtime describe the same snapshot of the
            # file (and the filesystem is queried once instead of twice).
            stat_result = path.stat()
            size_bytes = stat_result.st_size
            info["exists"] = True
            info["size_bytes"] = size_bytes
            info["size_mb"] = size_bytes / (1024 * 1024)
            info["modified_time"] = stat_result.st_mtime
    except Exception as e:
        logger.error(f"Failed to get model info: {str(e)}")
    return info
def _default_hf_hub_cache_dir() -> Path:
    """Resolve the HuggingFace Hub cache dir using the documented env-var precedence."""
    # HF_HUB_CACHE wins, then the legacy HUGGINGFACE_HUB_CACHE, then <HF_HOME>/hub.
    explicit = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return Path(explicit).expanduser()
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        base = Path(hf_home).expanduser()
    else:
        xdg_cache = os.environ.get("XDG_CACHE_HOME")
        cache_root = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
        base = cache_root / "huggingface"
    return base / "hub"


def cleanup_model_cache(cache_dir: Optional[str] = None) -> None:
    """
    Delete a HuggingFace model cache directory (recursively).

    Args:
        cache_dir: Directory to remove. When None, the HuggingFace Hub cache is
            targeted: ``$HF_HUB_CACHE``, else ``$HUGGINGFACE_HUB_CACHE``, else
            ``<HF_HOME>/hub``, where HF_HOME defaults to
            ``$XDG_CACHE_HOME/huggingface`` or ``~/.cache/huggingface``.
            Only the ``hub`` sub-directory is removed by default, so the stored
            access token and other HuggingFace caches under HF_HOME survive.

    Errors are logged rather than raised (best-effort cleanup).
    """
    try:
        if cache_dir is None:
            cache_path = _default_hf_hub_cache_dir()
        else:
            cache_path = Path(cache_dir)
        if cache_path.exists():
            logger.info(f"Cleaning up cache directory: {cache_path}")
            shutil.rmtree(cache_path)
            logger.info("Cache cleanup completed")
        else:
            logger.info("Cache directory does not exist")
    except Exception as e:
        logger.error(f"Failed to cleanup cache: {str(e)}")
def setup_model_directory() -> str:
    """
    Ensure the directory that will hold the model file exists.

    The directory is the parent of ``settings.MODEL_PATH``. It is created,
    including any missing parents, if necessary.

    Returns:
        The model directory as a string
    """
    directory = Path(settings.MODEL_PATH).parent
    directory.mkdir(exist_ok=True, parents=True)
    logger.info(f"Model directory setup: {directory}")
    return str(directory)
if __name__ == "__main__":
    # Command line utility for model management.
    import argparse

    parser = argparse.ArgumentParser(description="Model management utility")
    parser.add_argument("--download", action="store_true", help="Download model from HuggingFace")
    parser.add_argument("--verify", action="store_true", help="Verify model file")
    parser.add_argument("--info", action="store_true", help="Show model information")
    parser.add_argument("--list-files", action="store_true", help="List available files in HF repo")
    parser.add_argument("--cleanup-cache", action="store_true", help="Cleanup model cache")
    parser.add_argument("--force", action="store_true", help="Force download even if file exists")
    args = parser.parse_args()

    # --force is only a modifier for --download. Without any action flag the
    # script previously exited silently, so show usage instead.
    requested_actions = (args.download, args.verify, args.info, args.list_files, args.cleanup_cache)
    if not any(requested_actions):
        parser.print_help()

    if args.download:
        setup_model_directory()
        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=f"{settings.MODEL_NAME}.pt",
            local_dir=str(Path(settings.MODEL_PATH).parent),
            force_download=args.force
        )

    if args.verify:
        is_valid = verify_model_file(settings.MODEL_PATH)
        print(f"Model valid: {is_valid}")

    if args.info:
        info = get_model_info(settings.MODEL_PATH)
        print(f"Model info: {info}")

    if args.list_files:
        files = list_available_files(settings.HUGGINGFACE_REPO)
        print(f"Available files in {settings.HUGGINGFACE_REPO}:")
        for file in files:
            print(f"  - {file}")

    if args.cleanup_cache:
        cleanup_model_cache()