"""Helpers for pushing and pulling model checkpoints on the Hugging Face Hub."""

import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional

from huggingface_hub import HfApi, create_repo, login, snapshot_download, upload_file


class HFModelManager:
    """Utility class for managing model checkpoints on Hugging Face Hub.

    Every Hub call made by this class passes ``self.token`` explicitly, so the
    credentials used do not depend on which manager called ``login()`` last.
    """

    # Local artefacts uploaded by ``upload_model("models", ...)``.
    MODEL_EXPORT_FILES = (
        "models/exports/resnet_item_embedder_best.pth",
        "models/exports/vit_outfit_model_best.pth",
        "models/exports/resnet_metrics.json",
        "models/exports/vit_metrics.json",
    )

    # Local dataset split files uploaded by ``upload_model("splits", ...)``.
    SPLIT_FILES = (
        "data/Polyvore/splits/train.json",
        "data/Polyvore/splits/valid.json",
        "data/Polyvore/splits/test.json",
        "data/Polyvore/splits/outfit_triplets_train.json",
        "data/Polyvore/splits/outfit_triplets_valid.json",
        "data/Polyvore/splits/outfit_triplets_test.json",
    )

    def __init__(self, token: Optional[str] = None, username: Optional[str] = None):
        """Resolve credentials, log in to the Hub and create the API client.

        Args:
            token: Hub access token. Falls back to the ``HF_TOKEN``
                environment variable when omitted or empty.
            username: Hub user or organisation name used as the repo
                namespace. Falls back to ``HF_USERNAME`` when omitted or empty.

        Raises:
            ValueError: If no token or no username could be resolved.
            Exception: Whatever ``huggingface_hub.login`` raises is re-raised
                after a diagnostic message is printed.
        """
        self.token = token or os.getenv("HF_TOKEN")
        self.username = username or os.getenv("HF_USERNAME")

        if not self.token:
            raise ValueError("HF_TOKEN environment variable must be set")
        if not self.username:
            raise ValueError("HF_USERNAME environment variable must be set")

        # Set up authentication.
        # NOTE(review): `write_permission` appears to be deprecated in recent
        # huggingface_hub releases - confirm against the pinned version
        # before upgrading the dependency.
        try:
            login(token=self.token, write_permission=True)
            print("✅ Hugging Face authentication successful")
        except Exception as e:
            print(f"⚠️ Hugging Face authentication failed: {e}")
            raise

        self.api = HfApi(token=self.token)

    def create_model_repo(self, model_name: str, private: bool = False) -> str:
        """Create a model repository under ``self.username`` if it is missing.

        Creation is best-effort: a failure is printed, not raised.

        Args:
            model_name: Repository name without the namespace prefix.
            private: Whether the repository should be private.

        Returns:
            The ``"<username>/<model_name>"`` repo id, whether or not the
            creation call succeeded.
        """
        repo_id = f"{self.username}/{model_name}"
        try:
            create_repo(
                repo_id=repo_id,
                repo_type="model",
                private=private,
                exist_ok=True,
                token=self.token,
            )
        except Exception as e:
            print(f"Failed to create repo {repo_id}: {e}")
        return repo_id

    def push_checkpoint(
        self,
        local_path: str,
        repo_id: str,
        commit_message: str = "Update model checkpoint",
    ) -> bool:
        """Push a local checkpoint file to a model repository.

        The file is stored at the repository root under its base name.

        Args:
            local_path: Path of the file to upload.
            repo_id: Target repository id.
            commit_message: Commit message for the upload.

        Returns:
            ``True`` on success; ``False`` if the file does not exist or the
            upload raised (the reason is printed).
        """
        try:
            if not os.path.exists(local_path):
                print(f"Checkpoint not found: {local_path}")
                return False

            # Upload the checkpoint file
            upload_file(
                path_or_fileobj=local_path,
                path_in_repo=os.path.basename(local_path),
                repo_id=repo_id,
                repo_type="model",
                commit_message=commit_message,
                token=self.token,
            )
            print(f"Successfully pushed {local_path} to {repo_id}")
            return True
        except Exception as e:
            print(f"Failed to push checkpoint: {e}")
            return False

    def push_metrics(
        self,
        metrics: Dict[str, Any],
        repo_id: str,
        filename: str = "training_metrics.json",
    ) -> bool:
        """Serialise ``metrics`` as JSON and push it to a model repository.

        Args:
            metrics: JSON-serialisable mapping of metric names to values.
            repo_id: Target repository id.
            filename: Path of the metrics file inside the repository.

        Returns:
            ``True`` on success; ``False`` if serialisation or the upload
            raised (the reason is printed).
        """
        try:
            # Upload straight from memory: no temp file to clean up, no
            # dependency on a writable /tmp, no clash between concurrent runs.
            payload = json.dumps(metrics, indent=2).encode("utf-8")

            # Upload metrics
            upload_file(
                path_or_fileobj=payload,
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="model",
                commit_message="Update training metrics",
                token=self.token,
            )
            print(f"Successfully pushed metrics to {repo_id}")
            return True
        except Exception as e:
            print(f"Failed to push metrics: {e}")
            return False

    def download_checkpoint(
        self,
        repo_id: str,
        local_dir: str = "./models",
        filename: Optional[str] = None,
    ) -> Optional[str]:
        """Download one file, or the whole model repository, into ``local_dir``.

        Args:
            repo_id: Source repository id.
            local_dir: Destination directory; created if missing.
            filename: File to fetch. When omitted or empty, the entire
                repository is downloaded.

        Returns:
            The path of the downloaded file when ``filename`` is given and the
            file exists afterwards, ``local_dir`` for a full download, or
            ``None`` on failure (the reason is printed).
        """
        try:
            os.makedirs(local_dir, exist_ok=True)

            if filename:
                # Download specific file
                local_path = os.path.join(local_dir, filename)
                snapshot_download(
                    repo_id=repo_id,
                    repo_type="model",
                    local_dir=local_dir,
                    allow_patterns=[filename],
                    token=self.token,
                )
                return local_path if os.path.exists(local_path) else None

            # Download entire repo
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                token=self.token,
            )
            return local_dir
        except Exception as e:
            print(f"Failed to download checkpoint: {e}")
            return None

    def list_repo_files(self, repo_id: str) -> list:
        """List the file names in a model repository.

        Args:
            repo_id: Repository id to inspect.

        Returns:
            The ``rfilename`` of every sibling reported by ``model_info``, or
            an empty list if the lookup failed (the reason is printed) or the
            repository reports no files.
        """
        try:
            repo_info = self.api.model_info(repo_id)
            # `siblings` may be None; treat that as an empty repository.
            return [f.rfilename for f in (repo_info.siblings or [])]
        except Exception as e:
            print(f"Failed to list repo files: {e}")
            return []

    def _upload_models(self, repo_name: str) -> Dict[str, Any]:
        """Upload the existing ``MODEL_EXPORT_FILES`` to a public model repo."""
        repo_id = self.create_model_repo(repo_name, private=False)

        uploaded_files = []
        for file_path in self.MODEL_EXPORT_FILES:
            if not os.path.exists(file_path):
                continue
            base_name = os.path.basename(file_path)
            if self.push_checkpoint(file_path, repo_id, f"Upload {base_name}"):
                uploaded_files.append(base_name)

        return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}

    def _upload_splits(self, repo_name: str) -> Dict[str, Any]:
        """Upload the existing ``SPLIT_FILES`` to ``splits/`` in a dataset repo."""
        repo_id = f"{self.username}/{repo_name}"
        try:
            create_repo(
                repo_id=repo_id,
                repo_type="dataset",
                private=False,
                exist_ok=True,
                token=self.token,
            )
        except Exception as e:
            print(f"Note: Repo might already exist: {e}")

        uploaded_files = []
        for file_path in self.SPLIT_FILES:
            if not os.path.exists(file_path):
                continue
            base_name = os.path.basename(file_path)
            try:
                upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=f"splits/{base_name}",
                    repo_id=repo_id,
                    repo_type="dataset",
                    commit_message=f"Upload {base_name}",
                    token=self.token,
                )
                uploaded_files.append(base_name)
            except Exception as e:
                # One failed file must not abort the remaining uploads.
                print(f"Failed to upload {file_path}: {e}")

        return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}

    def upload_model(self, model_type: str, repo_name: str) -> Dict[str, Any]:
        """Upload models or data to HF Hub based on type.

        Args:
            model_type: ``"models"`` uploads ``MODEL_EXPORT_FILES`` to a model
                repo, ``"splits"`` uploads ``SPLIT_FILES`` to a dataset repo,
                and ``"everything"`` does both using the fixed repo names
                ``"dressify-models"`` and ``"Dressify-Helper"``.
            repo_name: Repository name without the namespace prefix. Ignored
                for ``"everything"``.

        Returns:
            For ``"models"`` and ``"splits"``: a dict with ``success``,
            ``uploaded_files`` and ``repo_id``. ``success`` is ``True`` even
            when no file was uploaded; inspect ``uploaded_files``. For
            ``"everything"``: ``success`` plus the two nested results under
            ``models`` and ``splits``. For an unknown type or an unexpected
            exception: ``{"success": False, "error": <message>}``.
        """
        try:
            if model_type == "models":
                return self._upload_models(repo_name)

            if model_type == "splits":
                return self._upload_splits(repo_name)

            if model_type == "everything":
                models_result = self.upload_model("models", "dressify-models")
                splits_result = self.upload_model("splits", "Dressify-Helper")
                return {
                    "success": models_result["success"] and splits_result["success"],
                    "models": models_result,
                    "splits": splits_result,
                }

            return {"success": False, "error": f"Unknown model type: {model_type}"}
        except Exception as e:
            return {"success": False, "error": str(e)}


def push_model_to_hub(
    checkpoint_path: str,
    model_name: str,
    token: Optional[str] = None,
    username: Optional[str] = None,
    private: bool = False,
) -> bool:
    """Convenience function to push a model checkpoint to HF Hub.

    Args:
        checkpoint_path: Local file to upload.
        model_name: Repository name without the namespace prefix; the
            repository is created if missing.
        token: Hub access token (defaults to ``HF_TOKEN``).
        username: Hub namespace (defaults to ``HF_USERNAME``).
        private: Whether a newly created repository should be private.

    Returns:
        ``True`` if the checkpoint was uploaded, ``False`` otherwise.

    Raises:
        ValueError: If no token or username could be resolved.
    """
    manager = HFModelManager(token=token, username=username)
    repo_id = manager.create_model_repo(model_name, private=private)
    return manager.push_checkpoint(checkpoint_path, repo_id)


def download_model_from_hub(
    repo_id: str,
    local_dir: str = "./models",
    filename: Optional[str] = None,
) -> Optional[str]:
    """Convenience function to download a model from HF Hub.

    Credentials come from the ``HF_TOKEN`` and ``HF_USERNAME`` environment
    variables; both must be set because a ``HFModelManager`` is constructed.

    Args:
        repo_id: Source repository id.
        local_dir: Destination directory.
        filename: Single file to fetch; the whole repo when omitted.

    Returns:
        The downloaded path, or ``None`` on failure.

    Raises:
        ValueError: If ``HF_TOKEN`` or ``HF_USERNAME`` is not set.
    """
    manager = HFModelManager()
    return manager.download_checkpoint(repo_id, local_dir, filename)


def main() -> None:
    """Command-line entry point: push a checkpoint or download a repository."""
    import argparse

    parser = argparse.ArgumentParser(description="HF Hub model management")
    parser.add_argument("--action", choices=["push", "download"], required=True)
    parser.add_argument("--checkpoint", type=str, help="Local checkpoint path")
    parser.add_argument("--repo", type=str, help="Repository ID")
    parser.add_argument("--model-name", type=str, help="Model name for new repo")
    parser.add_argument("--local-dir", type=str, default="./models", help="Local directory")

    args = parser.parse_args()

    if args.action == "push":
        if not args.checkpoint or not args.model_name:
            print("--checkpoint and --model-name required for push")
            # sys.exit instead of the `site`-provided exit(), which is not
            # guaranteed to exist in every interpreter configuration.
            sys.exit(1)

        success = push_model_to_hub(args.checkpoint, args.model_name)
        print(f"Push {'successful' if success else 'failed'}")

    elif args.action == "download":
        if not args.repo:
            print("--repo required for download")
            sys.exit(1)

        result = download_model_from_hub(args.repo, args.local_dir)
        if result:
            print(f"Downloaded to: {result}")
        else:
            print("Download failed")


if __name__ == "__main__":
    # Example usage
    main()