Spaces:
Paused
Paused
| import os | |
| import json | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download | |
class HFModelManager:
    """Utility class for managing model checkpoints on Hugging Face Hub.

    The manager resolves a Hub token and username once, logs in, and then
    exposes helpers to create repositories and to upload/download files.
    Apart from ``__init__``, the methods are best-effort: on failure they
    print a diagnostic and return a falsy/failed result instead of raising.
    """

    # Export artifacts pushed by ``upload_model("models", ...)``. Paths are
    # resolved relative to the current working directory.
    MODEL_EXPORT_FILES = (
        "models/exports/resnet_item_embedder_best.pth",
        "models/exports/vit_outfit_model_best.pth",
        "models/exports/resnet_metrics.json",
        "models/exports/vit_metrics.json",
    )

    # Dataset split files pushed by ``upload_model("splits", ...)``. Paths are
    # resolved relative to the current working directory.
    SPLIT_FILES = (
        "data/Polyvore/splits/train.json",
        "data/Polyvore/splits/valid.json",
        "data/Polyvore/splits/test.json",
        "data/Polyvore/splits/outfit_triplets_train.json",
        "data/Polyvore/splits/outfit_triplets_valid.json",
        "data/Polyvore/splits/outfit_triplets_test.json",
    )

    def __init__(self, token: Optional[str] = None, username: Optional[str] = None):
        """Resolve credentials, log in to the Hub and create the API client.

        Args:
            token: Hub access token. Falls back to the ``HF_TOKEN``
                environment variable when omitted or empty.
            username: Hub namespace used to build ``<username>/<name>`` repo
                ids. Falls back to ``HF_USERNAME`` when omitted or empty.

        Raises:
            ValueError: If no token or no username could be resolved.
            Exception: Whatever ``huggingface_hub.login`` raises is printed
                and re-raised unchanged.
        """
        self.token = token or os.getenv("HF_TOKEN")
        self.username = username or os.getenv("HF_USERNAME")
        if not self.token:
            raise ValueError("HF_TOKEN environment variable must be set")
        if not self.username:
            raise ValueError("HF_USERNAME environment variable must be set")

        # Set up authentication
        try:
            from huggingface_hub import login

            # ``write_permission`` is intentionally not passed: it was
            # deprecated by huggingface_hub and is rejected once removed.
            # NOTE(review): confirm against the pinned huggingface_hub version.
            login(token=self.token)
            print("✅ Hugging Face authentication successful")
        except Exception as e:
            print(f"⚠️ Hugging Face authentication failed: {e}")
            raise

        self.api = HfApi(token=self.token)

    def create_model_repo(self, model_name: str, private: bool = False) -> str:
        """Create a model repository named ``<username>/<model_name>``.

        An already existing repository is not an error (``exist_ok=True``).
        Any failure is printed and swallowed.

        Args:
            model_name: Repository name without the namespace.
            private: Whether the repository should be private.

        Returns:
            The repo id ``<username>/<model_name>``, returned whether or not
            the creation call succeeded.
        """
        repo_id = f"{self.username}/{model_name}"
        try:
            create_repo(
                repo_id=repo_id,
                repo_type="model",
                private=private,
                exist_ok=True,
                # Explicit token: do not depend on process-global login state.
                token=self.token,
            )
        except Exception as e:
            print(f"Failed to create repo {repo_id}: {e}")
        return repo_id

    def push_checkpoint(
        self,
        local_path: str,
        repo_id: str,
        commit_message: str = "Update model checkpoint"
    ) -> bool:
        """Push a local checkpoint file to a model repository on the Hub.

        The file is stored at the repository root under its base name.

        Args:
            local_path: Path of the file to upload.
            repo_id: Target repository id.
            commit_message: Commit message for the upload.

        Returns:
            True when the upload call completed; False when the path does not
            exist or any exception was raised (the reason is printed).
        """
        try:
            if not os.path.exists(local_path):
                print(f"Checkpoint not found: {local_path}")
                return False

            # Upload the checkpoint file
            upload_file(
                path_or_fileobj=local_path,
                path_in_repo=os.path.basename(local_path),
                repo_id=repo_id,
                repo_type="model",
                commit_message=commit_message,
                token=self.token,
            )
            print(f"Successfully pushed {local_path} to {repo_id}")
            return True
        except Exception as e:
            print(f"Failed to push checkpoint: {e}")
            return False

    def push_metrics(
        self,
        metrics: Dict[str, Any],
        repo_id: str,
        filename: str = "training_metrics.json"
    ) -> bool:
        """Serialize ``metrics`` as JSON and push it to a model repository.

        Args:
            metrics: JSON-serializable mapping of metric values.
            repo_id: Target repository id.
            filename: Path of the JSON file inside the repository.

        Returns:
            True when the upload call completed; False when serialization or
            the upload raised (the reason is printed).
        """
        try:
            # Upload the serialized bytes directly instead of staging them in
            # a hard-coded /tmp path: nothing is left behind when the upload
            # fails, and ``filename`` may contain sub-directories.
            payload = json.dumps(metrics, indent=2).encode("utf-8")
            upload_file(
                path_or_fileobj=payload,
                path_in_repo=filename,
                repo_id=repo_id,
                repo_type="model",
                commit_message="Update training metrics",
                token=self.token,
            )
            print(f"Successfully pushed metrics to {repo_id}")
            return True
        except Exception as e:
            print(f"Failed to push metrics: {e}")
            return False

    def download_checkpoint(
        self,
        repo_id: str,
        local_dir: str = "./models",
        filename: Optional[str] = None
    ) -> Optional[str]:
        """Download a checkpoint from a model repository on the Hub.

        Args:
            repo_id: Source repository id.
            local_dir: Directory to download into; created when missing.
            filename: When given, only files matching this name/pattern are
                downloaded; otherwise the entire repository is downloaded.

        Returns:
            With ``filename``: ``local_dir/filename`` if that path exists
            after the download, else None. Without ``filename``: ``local_dir``.
            None when any exception was raised (the reason is printed).
        """
        try:
            os.makedirs(local_dir, exist_ok=True)

            if filename:
                # Download specific file
                local_path = os.path.join(local_dir, filename)
                snapshot_download(
                    repo_id=repo_id,
                    repo_type="model",
                    local_dir=local_dir,
                    allow_patterns=[filename],
                    token=self.token,
                )
                return local_path if os.path.exists(local_path) else None

            # Download entire repo
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                token=self.token,
            )
            return local_dir
        except Exception as e:
            print(f"Failed to download checkpoint: {e}")
            return None

    def list_repo_files(self, repo_id: str) -> list:
        """List the file names of a model repository.

        Args:
            repo_id: Repository id to inspect.

        Returns:
            The ``rfilename`` of every sibling reported by ``model_info``;
            an empty list when there are none or when the lookup raised
            (the reason is printed).
        """
        try:
            repo_info = self.api.model_info(repo_id)
            # ``siblings`` may be None rather than an empty list.
            return [f.rfilename for f in (repo_info.siblings or [])]
        except Exception as e:
            print(f"Failed to list repo files: {e}")
            return []

    @staticmethod
    def _upload_summary(repo_id: str, uploaded_files: list, failed_files: list) -> Dict[str, Any]:
        """Build the result dict for one batch upload."""
        summary = {
            "success": not failed_files,
            "uploaded_files": uploaded_files,
            "failed_files": failed_files,
            "repo_id": repo_id,
        }
        if failed_files:
            summary["error"] = f"Failed to upload: {', '.join(failed_files)}"
        return summary

    def upload_model(self, model_type: str, repo_name: str) -> Dict[str, Any]:
        """Upload models or data to HF Hub based on type.

        Files that do not exist locally are skipped; files that exist but
        fail to upload are reported and make the batch unsuccessful.

        Args:
            model_type: ``"models"`` uploads ``MODEL_EXPORT_FILES`` to the
                model repo ``<username>/<repo_name>``; ``"splits"`` uploads
                ``SPLIT_FILES`` under ``splits/`` in the dataset repo
                ``<username>/<repo_name>``; ``"everything"`` does both using
                the fixed repo names ``dressify-models`` and
                ``Dressify-Helper`` (``repo_name`` is ignored in that case).
            repo_name: Repository name without the namespace.

        Returns:
            For ``"models"``/``"splits"``: a dict with ``success``,
            ``uploaded_files``, ``failed_files`` and ``repo_id`` (plus
            ``error`` when something failed). For ``"everything"``: a dict
            with the combined ``success`` and the two sub-results under
            ``models`` and ``splits``. For an unknown type or an unexpected
            exception: ``{"success": False, "error": ...}``.
        """
        try:
            if model_type == "models":
                # Upload model checkpoints
                repo_id = self.create_model_repo(repo_name, private=False)

                uploaded_files = []
                failed_files = []
                for file_path in self.MODEL_EXPORT_FILES:
                    if not os.path.exists(file_path):
                        continue  # artifact not present locally; skip it
                    base_name = os.path.basename(file_path)
                    if self.push_checkpoint(file_path, repo_id, f"Upload {base_name}"):
                        uploaded_files.append(base_name)
                    else:
                        failed_files.append(base_name)

                return self._upload_summary(repo_id, uploaded_files, failed_files)

            elif model_type == "splits":
                # Upload dataset splits
                repo_id = f"{self.username}/{repo_name}"
                try:
                    create_repo(
                        repo_id=repo_id,
                        repo_type="dataset",
                        private=False,
                        exist_ok=True,
                        token=self.token,
                    )
                except Exception as e:
                    print(f"Note: Repo might already exist: {e}")

                uploaded_files = []
                failed_files = []
                for file_path in self.SPLIT_FILES:
                    if not os.path.exists(file_path):
                        continue  # split not present locally; skip it
                    base_name = os.path.basename(file_path)
                    try:
                        upload_file(
                            path_or_fileobj=file_path,
                            path_in_repo=f"splits/{base_name}",
                            repo_id=repo_id,
                            repo_type="dataset",
                            commit_message=f"Upload {base_name}",
                            token=self.token,
                        )
                        uploaded_files.append(base_name)
                    except Exception as e:
                        print(f"Failed to upload {file_path}: {e}")
                        failed_files.append(base_name)

                return self._upload_summary(repo_id, uploaded_files, failed_files)

            elif model_type == "everything":
                # Upload everything
                models_result = self.upload_model("models", "dressify-models")
                splits_result = self.upload_model("splits", "Dressify-Helper")
                return {
                    "success": models_result["success"] and splits_result["success"],
                    "models": models_result,
                    "splits": splits_result
                }

            else:
                return {"success": False, "error": f"Unknown model type: {model_type}"}

        except Exception as e:
            return {"success": False, "error": str(e)}
def push_model_to_hub(
    checkpoint_path: str,
    model_name: str,
    token: Optional[str] = None,
    username: Optional[str] = None,
    private: bool = False
) -> bool:
    """Create (if needed) the model repo and upload one checkpoint to it.

    Args:
        checkpoint_path: Local path of the checkpoint file.
        model_name: Repository name, without the namespace.
        token: Hub token forwarded to ``HFModelManager``.
        username: Hub namespace forwarded to ``HFModelManager``.
        private: Whether a newly created repository should be private.

    Returns:
        The boolean result of ``HFModelManager.push_checkpoint``.
    """
    hub = HFModelManager(username=username, token=token)
    target_repo = hub.create_model_repo(model_name=model_name, private=private)
    pushed = hub.push_checkpoint(local_path=checkpoint_path, repo_id=target_repo)
    return pushed
def download_model_from_hub(
    repo_id: str,
    local_dir: str = "./models",
    filename: Optional[str] = None
) -> Optional[str]:
    """Download a model repo (or a single file of it) from HF Hub.

    Builds an ``HFModelManager`` with default arguments and delegates to its
    ``download_checkpoint`` method.

    Args:
        repo_id: Source repository id.
        local_dir: Directory to download into.
        filename: Optional file to restrict the download to.

    Returns:
        Whatever ``HFModelManager.download_checkpoint`` returns.
    """
    hub = HFModelManager()
    downloaded = hub.download_checkpoint(
        repo_id=repo_id,
        local_dir=local_dir,
        filename=filename,
    )
    return downloaded
if __name__ == "__main__":
    # Example usage: a small CLI around push_model_to_hub and
    # download_model_from_hub. The process exits with status 1 on missing
    # arguments and when the push/download itself fails, so shell callers
    # can detect failure.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="HF Hub model management")
    parser.add_argument("--action", choices=["push", "download"], required=True)
    parser.add_argument("--checkpoint", type=str, help="Local checkpoint path")
    parser.add_argument("--repo", type=str, help="Repository ID")
    parser.add_argument("--model-name", type=str, help="Model name for new repo")
    parser.add_argument("--local-dir", type=str, default="./models", help="Local directory")
    args = parser.parse_args()

    if args.action == "push":
        if not args.checkpoint or not args.model_name:
            print("--checkpoint and --model-name required for push")
            # sys.exit instead of the site-provided ``exit`` helper, which is
            # not guaranteed to exist in every interpreter setup.
            sys.exit(1)
        success = push_model_to_hub(args.checkpoint, args.model_name)
        print(f"Push {'successful' if success else 'failed'}")
        if not success:
            sys.exit(1)
    elif args.action == "download":
        if not args.repo:
            print("--repo required for download")
            sys.exit(1)
        result = download_model_from_hub(args.repo, args.local_dir)
        if result:
            print(f"Downloaded to: {result}")
        else:
            print("Download failed")
            sys.exit(1)