recomendation / utils /hf_utils.py
Ali Mohsin
Fixed 1 million more errors
227af5e
import os
import json
from pathlib import Path
from typing import Optional, Dict, Any
from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download
class HFModelManager:
    """Utility class for managing model checkpoints on Hugging Face Hub.

    Credentials come from the constructor arguments, falling back to the
    ``HF_TOKEN`` / ``HF_USERNAME`` environment variables. Every Hub call made
    by this class passes the token explicitly (via ``self.api`` or the
    ``token=`` argument) so it does not depend on a globally stored login.
    """

    # Artifacts pushed by ``upload_model("models", ...)``; missing files are skipped.
    _MODEL_EXPORT_FILES = (
        "models/exports/resnet_item_embedder_best.pth",
        "models/exports/vit_outfit_model_best.pth",
        "models/exports/resnet_metrics.json",
        "models/exports/vit_metrics.json",
    )

    # Dataset split files pushed by ``upload_model("splits", ...)``; missing files are skipped.
    _SPLIT_FILES = (
        "data/Polyvore/splits/train.json",
        "data/Polyvore/splits/valid.json",
        "data/Polyvore/splits/test.json",
        "data/Polyvore/splits/outfit_triplets_train.json",
        "data/Polyvore/splits/outfit_triplets_valid.json",
        "data/Polyvore/splits/outfit_triplets_test.json",
    )

    def __init__(self, token: Optional[str] = None, username: Optional[str] = None):
        """Resolve credentials, log in to the Hub and build an authenticated client.

        Args:
            token: Hugging Face access token; defaults to ``$HF_TOKEN``.
            username: Hub namespace used to build repo ids; defaults to
                ``$HF_USERNAME``.

        Raises:
            ValueError: if no token or no username can be resolved.
            Exception: anything raised by ``huggingface_hub.login`` is re-raised.
        """
        self.token = token or os.getenv("HF_TOKEN")
        self.username = username or os.getenv("HF_USERNAME")
        if not self.token:
            raise ValueError("HF_TOKEN environment variable must be set")
        if not self.username:
            raise ValueError("HF_USERNAME environment variable must be set")
        # Set up authentication. `write_permission` is deprecated in
        # huggingface_hub (removed in 1.0), so it is intentionally not passed.
        try:
            from huggingface_hub import login
            login(token=self.token)
            print("✅ Hugging Face authentication successful")
        except Exception as e:
            print(f"⚠️ Hugging Face authentication failed: {e}")
            raise
        self.api = HfApi(token=self.token)

    def create_model_repo(self, model_name: str, private: bool = False) -> str:
        """Ensure the model repo ``{username}/{model_name}`` exists.

        Best-effort: a failure is printed and swallowed, and the repo id is
        returned either way so callers can still attempt uploads.

        Returns:
            The repo id ``"{username}/{model_name}"``.
        """
        repo_id = f"{self.username}/{model_name}"
        try:
            self.api.create_repo(
                repo_id=repo_id,
                repo_type="model",
                private=private,
                exist_ok=True,
            )
        except Exception as e:
            print(f"Failed to create repo {repo_id}: {e}")
        return repo_id

    def push_checkpoint(
        self,
        local_path: str,
        repo_id: str,
        commit_message: str = "Update model checkpoint"
    ) -> bool:
        """Upload one local file to the root of a model repo.

        The file is stored under its basename. Returns ``True`` on success and
        ``False`` if the file is missing or the upload raises (the error is
        printed, not propagated).
        """
        try:
            if not os.path.exists(local_path):
                print(f"Checkpoint not found: {local_path}")
                return False
            self.api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=os.path.basename(local_path),
                repo_id=repo_id,
                repo_type="model",
                commit_message=commit_message,
            )
        except Exception as e:
            print(f"Failed to push checkpoint: {e}")
            return False
        print(f"Successfully pushed {local_path} to {repo_id}")
        return True

    def push_metrics(
        self,
        metrics: Dict[str, Any],
        repo_id: str,
        filename: str = "training_metrics.json"
    ) -> bool:
        """Serialize ``metrics`` as indented JSON and upload it as ``filename``.

        The JSON is uploaded straight from memory rather than through a
        temporary file. Nothing is left on disk if the upload fails, and
        concurrent runs cannot clobber each other's file.

        Returns:
            ``True`` on success, ``False`` if serialization or upload raises.
        """
        try:
            payload = json.dumps(metrics, indent=2).encode("utf-8")
            self.api.upload_file(
                path_or_fileobj=payload,
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="model",
                commit_message="Update training metrics",
            )
        except Exception as e:
            print(f"Failed to push metrics: {e}")
            return False
        print(f"Successfully pushed metrics to {repo_id}")
        return True

    def download_checkpoint(
        self,
        repo_id: str,
        local_dir: str = "./models",
        filename: Optional[str] = None
    ) -> Optional[str]:
        """Download one file (``filename``) or the whole model repo into ``local_dir``.

        Returns:
            The local path of ``filename`` if it was requested and now exists.
            Otherwise ``local_dir`` for a whole-repo download. ``None`` if the
            requested file is absent or the download raises (the error is
            printed).
        """
        try:
            os.makedirs(local_dir, exist_ok=True)
            if filename:
                # Restrict the snapshot to the single requested file.
                snapshot_download(
                    repo_id=repo_id,
                    repo_type="model",
                    local_dir=local_dir,
                    allow_patterns=[filename],
                    token=self.token,
                )
                local_path = os.path.join(local_dir, filename)
                return local_path if os.path.exists(local_path) else None
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                token=self.token,
            )
            return local_dir
        except Exception as e:
            print(f"Failed to download checkpoint: {e}")
            return None

    def list_repo_files(self, repo_id: str) -> list:
        """Return the paths of all files in a model repo, or ``[]`` on error."""
        try:
            return list(self.api.list_repo_files(repo_id, repo_type="model"))
        except Exception as e:
            print(f"Failed to list repo files: {e}")
            return []

    def upload_model(self, model_type: str, repo_name: str) -> Dict[str, Any]:
        """Upload models or data to HF Hub based on ``model_type``.

        Args:
            model_type: ``"models"`` (export checkpoints to a model repo),
                ``"splits"`` (split JSON files to a dataset repo) or
                ``"everything"`` (both).
            repo_name: Repo name under ``self.username``. NOTE: the
                ``"everything"`` mode ignores it and uses the fixed names
                ``dressify-models`` and ``Dressify-Helper``.

        Returns:
            A result dict with ``"success"``. On failure it holds an
            ``"error"`` message. Otherwise it holds the uploaded file names and
            the repo id (or nested results for ``"everything"``). Exceptions
            are converted into ``{"success": False, "error": ...}``.
        """
        try:
            if model_type == "models":
                return self._upload_model_exports(repo_name)
            if model_type == "splits":
                return self._upload_dataset_splits(repo_name)
            if model_type == "everything":
                models_result = self.upload_model("models", "dressify-models")
                splits_result = self.upload_model("splits", "Dressify-Helper")
                return {
                    "success": models_result["success"] and splits_result["success"],
                    "models": models_result,
                    "splits": splits_result,
                }
            return {"success": False, "error": f"Unknown model type: {model_type}"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def _upload_model_exports(self, repo_name: str) -> Dict[str, Any]:
        """Push every existing file in ``_MODEL_EXPORT_FILES`` to a public model repo."""
        repo_id = self.create_model_repo(repo_name, private=False)
        uploaded_files = []
        for file_path in self._MODEL_EXPORT_FILES:
            if not os.path.exists(file_path):
                continue
            name = os.path.basename(file_path)
            if self.push_checkpoint(file_path, repo_id, f"Upload {name}"):
                uploaded_files.append(name)
        return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}

    def _upload_dataset_splits(self, repo_name: str) -> Dict[str, Any]:
        """Push every existing file in ``_SPLIT_FILES`` under ``splits/`` in a public dataset repo."""
        repo_id = f"{self.username}/{repo_name}"
        try:
            self.api.create_repo(
                repo_id=repo_id,
                repo_type="dataset",
                private=False,
                exist_ok=True,
            )
        except Exception as e:
            print(f"Note: Repo might already exist: {e}")
        uploaded_files = []
        for file_path in self._SPLIT_FILES:
            if not os.path.exists(file_path):
                continue
            name = os.path.basename(file_path)
            try:
                self.api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=f"splits/{name}",
                    repo_id=repo_id,
                    repo_type="dataset",
                    commit_message=f"Upload {name}",
                )
            except Exception as e:
                print(f"Failed to upload {file_path}: {e}")
            else:
                uploaded_files.append(name)
        return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}
def push_model_to_hub(
    checkpoint_path: str,
    model_name: str,
    token: Optional[str] = None,
    username: Optional[str] = None,
    private: bool = False
) -> bool:
    """Upload a single checkpoint file to ``{username}/{model_name}`` on the Hub.

    The repo is created first (if needed) with the requested visibility. The
    file is then pushed to the repo root. Credentials fall back to
    ``$HF_TOKEN`` / ``$HF_USERNAME`` when not given.

    Returns:
        ``True`` if the upload succeeded, ``False`` otherwise.
    """
    hub = HFModelManager(token=token, username=username)
    # Arguments are evaluated before the call, so the repo is ensured before the push.
    return hub.push_checkpoint(
        checkpoint_path,
        hub.create_model_repo(model_name, private=private),
    )
def download_model_from_hub(
    repo_id: str,
    local_dir: str = "./models",
    filename: Optional[str] = None,
    token: Optional[str] = None,
    username: Optional[str] = None
) -> Optional[str]:
    """Convenience function to download a model from HF Hub.

    Args:
        repo_id: Full repo id, e.g. ``"user/model"``.
        local_dir: Destination directory (created if missing).
        filename: If given, download only this file; otherwise the whole repo.
        token: Access token; defaults to ``$HF_TOKEN`` (same as
            ``push_model_to_hub``).
        username: Hub username; defaults to ``$HF_USERNAME``.

    Returns:
        The downloaded file path, ``local_dir`` for a whole-repo download, or
        ``None`` on failure.
    """
    manager = HFModelManager(token=token, username=username)
    return manager.download_checkpoint(repo_id, local_dir, filename)
if __name__ == "__main__":
    # Command-line entry point: push a checkpoint to, or download a repo from, the Hub.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="HF Hub model management")
    parser.add_argument("--action", choices=["push", "download"], required=True)
    parser.add_argument("--checkpoint", type=str, help="Local checkpoint path")
    parser.add_argument("--repo", type=str, help="Repository ID")
    parser.add_argument("--model-name", type=str, help="Model name for new repo")
    parser.add_argument("--local-dir", type=str, default="./models", help="Local directory")
    parser.add_argument(
        "--filename",
        type=str,
        default=None,
        help="Download only this file from the repo (default: whole repo)",
    )
    args = parser.parse_args()

    if args.action == "push":
        if not args.checkpoint or not args.model_name:
            print("--checkpoint and --model-name required for push")
            sys.exit(1)
        success = push_model_to_hub(args.checkpoint, args.model_name)
        print(f"Push {'successful' if success else 'failed'}")
        # Non-zero exit status on failure so shell scripts / CI can detect it.
        sys.exit(0 if success else 1)
    elif args.action == "download":
        if not args.repo:
            print("--repo required for download")
            sys.exit(1)
        result = download_model_from_hub(args.repo, args.local_dir, args.filename)
        if result:
            print(f"Downloaded to: {result}")
        else:
            print("Download failed")
            sys.exit(1)