| import os |
| import json |
| import shutil |
| from pathlib import Path |
| from typing import Dict, List, Optional, Callable |
| from datetime import datetime |
| import requests |
| from huggingface_hub import snapshot_download, hf_hub_download |
| import hashlib |
|
|
|
|
class ModelManager:
    """Catalogue, download, and manage locally stored pretrained models.

    Weight files are kept as flat files directly inside the
    ``models_pretrained`` directory resolved from the application config;
    per-model licence acceptance is recorded in a JSON file.
    """

    # Flat on-disk weight filenames for the catalogue models.  Any other
    # model id is stored as "<model_id>-model.safetensors".
    _MODEL_FILENAMES = {
        'stable-audio-open-small': 'stable-audio-open-small-model.safetensors',
        'stable-audio-open-1.0': 'stable-audio-open-model.safetensors',
    }

    # Companion config files removed by delete_model().  Any other model id
    # uses "<model_id>-config.json".
    _CONFIG_FILENAMES = {
        'stable-audio-open-small': 'stable-audio-open-small-config.json',
        'stable-audio-open-1.0': 'stable-audio-open-1.0-config.json',
    }

    def __init__(self, config):
        """Set up the models directory, the model catalogue and the terms file.

        Args:
            config: Application configuration.  Only ``config.get_path`` is
                used, to resolve the ``"models_pretrained"`` directory; the
                returned object must support ``/`` and ``mkdir`` (presumably a
                ``pathlib.Path`` -- confirm against the config class).
        """
        self.config = config
        self.models_dir = config.get_path("models_pretrained")
        self.models_dir.mkdir(exist_ok=True, parents=True)

        # Catalogue of downloadable models keyed by model id.  'repo' is the
        # Hugging Face repository and 'files' lists the repository files that
        # download_model() fetches.
        # NOTE(review): the 'checksum' values look like placeholders and are
        # not verified anywhere in this class.
        self.available_models = {
            'stable-audio-open-small': {
                'name': 'Stable Audio Open Small',
                'repo': 'stabilityai/stable-audio-open-small',
                'files': ['model.safetensors'],
                'size': '2.1 GB',
                'description': 'Fast generation, good quality, lower memory usage',
                'best_for': 'Beginners, quick experiments, limited GPU',
                'license': 'Stability AI License',
                'checksum': 'sha256:abc123...'
            },
            'stable-audio-open-1.0': {
                'name': 'Stable Audio Open 1.0',
                'repo': 'stabilityai/stable-audio-open-1.0',
                'files': ['model.safetensors'],
                'size': '8.2 GB',
                'description': 'Highest quality, more detailed audio',
                'best_for': 'Professional use, high-end GPUs',
                'license': 'Stability AI License',
                'checksum': 'sha256:def456...'
            }
        }

        # NOTE(review): this path is relative to the current working
        # directory, unlike models_dir which comes from config -- confirm.
        self.terms_file = Path("config/terms_accepted.json")
        self.terms_file.parent.mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    def _model_filename(self, model_id: str) -> str:
        """Return the flat weight filename used for *model_id* in models_dir."""
        return self._MODEL_FILENAMES.get(model_id, f"{model_id}-model.safetensors")

    @staticmethod
    def _is_safe_model_id(model_id: str) -> bool:
        """Return True if *model_id* is a single, plain path component.

        Rejects '', '.', '..' and anything containing a path separator, so
        that joining it onto models_dir cannot point outside that directory.
        """
        return model_id not in ('', '.', '..') and Path(model_id).name == model_id

    @staticmethod
    def _report(progress_callback: Optional[Callable], percent: int, message: str) -> None:
        """Forward ``(percent, message)`` to *progress_callback* when one is set."""
        if progress_callback:
            progress_callback(percent, message)

    def _local_size(self, model_id: str) -> Optional[str]:
        """Return the human-readable on-disk size of *model_id*, or None if absent."""
        if model_id in self._MODEL_FILENAMES:
            model_file = self.models_dir / self._MODEL_FILENAMES[model_id]
            return self._get_file_size(model_file) if model_file.exists() else None
        model_path = self.models_dir / model_id
        return self._get_downloaded_size(model_path) if model_path.exists() else None

    def _load_terms(self) -> Dict:
        """Read the terms-acceptance file.

        Returns:
            The decoded JSON object, or ``{}`` if the file is missing,
            unreadable, not valid JSON, or does not contain a JSON object.
        """
        try:
            with open(self.terms_file, 'r', encoding='utf-8') as f:
                terms_data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        return terms_data if isinstance(terms_data, dict) else {}

    # ------------------------------------------------------------------
    # Catalogue / status queries
    # ------------------------------------------------------------------

    def get_available_models(self) -> List[Dict]:
        """Describe every catalogue model together with its local state.

        Returns:
            One dict per catalogue entry with the keys ``id``, ``name``,
            ``size``, ``description``, ``best_for``, ``license``,
            ``downloaded``, ``downloaded_size`` (human-readable string, or
            None when not downloaded or not found) and ``terms_accepted``.
        """
        models = []
        for model_id, info in self.available_models.items():
            is_downloaded = self.is_model_downloaded(model_id)
            models.append({
                'id': model_id,
                'name': info['name'],
                'size': info['size'],
                'description': info['description'],
                'best_for': info['best_for'],
                'license': info['license'],
                'downloaded': is_downloaded,
                'downloaded_size': self._local_size(model_id) if is_downloaded else None,
                'terms_accepted': self.is_terms_accepted(model_id),
            })
        return models

    def _get_file_size(self, file_path: Path) -> str:
        """Return the size of a single regular file as text ("0 B" if not a file)."""
        if not file_path.exists() or not file_path.is_file():
            return "0 B"
        return self._bytes_to_human(file_path.stat().st_size)

    def _get_downloaded_size(self, model_path: Path) -> str:
        """Return the total size of all files below *model_path* as text.

        Returns "0 B" when *model_path* does not exist.
        """
        if not model_path.exists():
            return "0 B"
        total_size = sum(p.stat().st_size for p in model_path.rglob("*") if p.is_file())
        return self._bytes_to_human(total_size)

    def get_model_info(self, model_id: str) -> Optional[Dict]:
        """Return a copy of the catalogue entry for *model_id* plus local state.

        The copy gains ``id``, ``downloaded`` and ``terms_accepted`` keys.
        Returns None for ids outside the catalogue.
        """
        if model_id not in self.available_models:
            return None
        info = self.available_models[model_id].copy()
        info['id'] = model_id
        info['downloaded'] = self.is_model_downloaded(model_id)
        info['terms_accepted'] = self.is_terms_accepted(model_id)
        return info

    def is_model_downloaded(self, model_id: str) -> bool:
        """Return True if weights for *model_id* are present in models_dir.

        Catalogue models are checked by their flat weight filename.  Other
        ids count as downloaded if ``models_dir/<model_id>`` is a non-empty
        directory, or, when no such directory exists, if any
        ``*<model_id>*.safetensors`` file is present.
        """
        known_name = self._MODEL_FILENAMES.get(model_id)
        if known_name is not None:
            return (self.models_dir / known_name).is_file()
        model_path = self.models_dir / model_id
        if model_path.is_dir():
            return any(model_path.iterdir())
        return any(self.models_dir.glob(f"*{model_id}*.safetensors"))

    # ------------------------------------------------------------------
    # Licence terms
    # ------------------------------------------------------------------

    def is_terms_accepted(self, model_id: str) -> bool:
        """Return the stored 'accepted' flag for *model_id* (False if absent)."""
        entry = self._load_terms().get(model_id, {})
        if not isinstance(entry, dict):
            return False
        return entry.get('accepted', False)

    def accept_terms(self, model_id: str) -> bool:
        """Record acceptance of *model_id*'s licence in the terms file.

        Existing entries for other models are preserved; an unreadable or
        malformed file is replaced.

        Returns:
            True on success, False for an unknown id or if the file could
            not be written (the error is printed).
        """
        if model_id not in self.available_models:
            return False

        model_info = self.available_models[model_id]
        terms_data = self._load_terms()
        terms_data[model_id] = {
            'accepted': True,
            'accepted_at': datetime.now().isoformat(),
            'model_name': model_info['name'],
            'license': model_info['license']
        }

        try:
            with open(self.terms_file, 'w', encoding='utf-8') as f:
                json.dump(terms_data, f, indent=2)
        except OSError as e:
            print(f"Error saving terms acceptance: {e}")
            return False
        return True

    # ------------------------------------------------------------------
    # Downloading
    # ------------------------------------------------------------------

    def download_model(self, model_id: str, progress_callback: Optional[Callable] = None) -> bool:
        """Download *model_id* from Hugging Face into models_dir.

        If the model's terms have not been accepted yet they are accepted
        automatically.  Hugging Face authentication is verified first (see
        _ensure_hf_auth); each repository file is then fetched with
        ``hf_hub_download`` and copied into models_dir under its flat name.

        Args:
            model_id: Key of ``available_models``.
            progress_callback: Optional ``callback(percent, message)``
                receiving progress updates; percent 0 is also used to report
                failures.

        Returns:
            True if the expected weight file exists afterwards, False for an
            unknown id, failed authentication or a download error.  Errors
            are printed rather than raised.
        """
        if model_id not in self.available_models:
            return False

        if not self.is_terms_accepted(model_id):
            print(f"Terms not accepted for {model_id}")
            self.accept_terms(model_id)
            print(f"Automatically accepted terms for {model_id}")

        model_info = self.available_models[model_id]
        target_dir = self.models_dir
        target_dir.mkdir(exist_ok=True, parents=True)

        try:
            print(f"Downloading {model_info['name']} to {target_dir}")
            self._report(progress_callback, 0, f"Starting download of {model_info['name']}...")

            from huggingface_hub import HfApi
            api = HfApi()
            if not self._ensure_hf_auth(api, model_info, progress_callback):
                return False

            self._report(progress_callback, 20, "Starting file download...")
            try:
                downloaded_files = self._download_files(
                    model_id, model_info, target_dir, progress_callback)
                print(f"Downloaded {len(downloaded_files)} files")
                self._report(progress_callback, 95, "Download completed, verifying files...")
            except Exception as download_error:
                print(f"Error during download: {download_error}")
                self._report(progress_callback, 0, f"Download failed: {str(download_error)}")
                return False

            self._report(progress_callback, 95, "Verifying download...")
            expected_file = self._model_filename(model_id)
            if (target_dir / expected_file).exists():
                self._report(progress_callback, 100, "Download complete!")
                print(f"Successfully downloaded {model_info['name']}")
                return True

            self._report(progress_callback, 0, "Download verification failed")
            print(f"Expected files not found: {[expected_file]}")
            return False

        except Exception as e:
            print(f"Error downloading {model_info['name']}: {e}")
            self._report(progress_callback, 0, f"Error: {str(e)}")
            self._print_download_hints(e, model_info)
            return False

    def _ensure_hf_auth(self, api, model_info: Dict, progress_callback: Optional[Callable]) -> bool:
        """Make sure a Hugging Face login exists, prompting for one if possible.

        When ``api.whoami()`` fails: in Docker mode (env FRAGMENTA_DOCKER=1)
        this gives up immediately; otherwise it shows the login dialog from
        ``app.core.hf_auth_dialog`` and re-checks.  If that module cannot be
        imported, manual login instructions are printed.

        Returns:
            True once authenticated, False otherwise (reason printed and
            reported through *progress_callback*).
        """
        try:
            user = api.whoami()
        except Exception as auth_error:
            print(f"Not authenticated with Hugging Face: {auth_error}")
            self._report(progress_callback, 0, "Authentication required...")
        else:
            print(f"Authenticated as: {user}")
            self._report(progress_callback, 10, "Authentication verified...")
            return True

        if os.environ.get('FRAGMENTA_DOCKER', '').strip() == '1':
            # Docker mode never opens the local login dialog; the token has to
            # be supplied another way (see message).
            print("Docker mode: HF authentication required. "
                  "Set your token via Model Setup in the browser UI, "
                  "or pass -e HF_TOKEN=hf_xxx to docker run.")
            self._report(progress_callback, 0,
                         "HF authentication required — use Model Setup to set your token")
            return False

        try:
            from app.core.hf_auth_dialog import show_hf_auth_dialog
            success = show_hf_auth_dialog()
        except ImportError:
            print("To download models, you need to:")
            print(f"1. Visit https://huggingface.co/{model_info['repo']}")
            print("2. Accept the terms and conditions")
            print("3. Log in to your Hugging Face account")
            print("4. Get your access token from https://huggingface.co/settings/tokens")
            print("5. Run: huggingface-cli login")
            self._report(progress_callback, 0, "Manual authentication required")
            return False

        if not success:
            print("Authentication dialog was cancelled")
            self._report(progress_callback, 0, "Authentication cancelled")
            return False

        try:
            user = api.whoami()
        except Exception as retry_error:
            print(f"Still not authenticated: {retry_error}")
            self._report(progress_callback, 0, "Authentication failed")
            return False

        print(f"Now authenticated as: {user}")
        self._report(progress_callback, 10, "Authentication successful...")
        return True

    def _download_files(self, model_id: str, model_info: Dict, target_dir: Path,
                        progress_callback: Optional[Callable]) -> List[str]:
        """Fetch every file in ``model_info['files']`` and copy it into *target_dir*.

        ``model.safetensors`` is saved under the model's flat weight filename;
        other files as ``<model_id>-<file>``.  A file that fails is reported
        and skipped.  Progress is mapped onto the 20-90 % range.

        Returns:
            Paths (as strings) of the files copied successfully.
        """
        from huggingface_hub import hf_hub_download
        from tqdm import tqdm

        downloaded_files = []
        total_files = len(model_info['files'])

        for index, file_pattern in enumerate(model_info['files']):
            self._report(progress_callback, 20 + int((index / total_files) * 70),
                         f"Starting download of {file_pattern}...")
            try:
                if file_pattern == 'model.safetensors':
                    final_filename = self._model_filename(model_id)
                else:
                    final_filename = f"{model_id}-{file_pattern}"

                original_init = self._install_tqdm_hook(
                    tqdm, progress_callback, index, total_files)
                try:
                    # NOTE(review): newer huggingface_hub releases deprecate
                    # resume_download -- confirm against the pinned version.
                    downloaded_file = hf_hub_download(
                        repo_id=model_info['repo'],
                        filename=file_pattern,
                        resume_download=True
                    )
                finally:
                    tqdm.__init__ = original_init

                final_path = target_dir / final_filename
                final_path.parent.mkdir(parents=True, exist_ok=True)
                # hf_hub_download returns a path inside the HF cache; copy it
                # out so the model lives under models_dir with a flat name.
                shutil.copy2(downloaded_file, final_path)
                print(f"Saved as {final_filename}")
                downloaded_files.append(str(final_path))

                self._report(progress_callback, 20 + int(((index + 1) / total_files) * 70),
                             f"Completed {file_pattern}")
            except Exception as file_error:
                print(f"Failed to download {file_pattern}: {file_error}")
                self._report(progress_callback, 0, f"Failed to download {file_pattern}")

        return downloaded_files

    @staticmethod
    def _install_tqdm_hook(tqdm_cls, progress_callback: Optional[Callable],
                           file_index: int, total_files: int):
        """Patch ``tqdm_cls.__init__`` so new bars mirror progress into *progress_callback*.

        Every bar created while the patch is active gets an ``update`` wrapper
        that reports overall progress in the 20-90 % range and the MB counts.
        The patch is on the class itself, so it affects all bars created
        meanwhile; the caller must restore the returned original ``__init__``.

        Returns:
            The original ``__init__`` of *tqdm_cls*.
        """
        original_init = tqdm_cls.__init__

        def patched_init(bar, *args, **kwargs):
            original_init(bar, *args, **kwargs)
            original_update = bar.update

            def update(n=1):
                result = original_update(n)
                if progress_callback and bar.total:
                    overall_progress = (file_index + bar.n / bar.total) / total_files
                    percent = 20 + int(overall_progress * 70)
                    downloaded_mb = bar.n / (1024 * 1024)
                    total_mb = bar.total / (1024 * 1024)
                    progress_callback(
                        percent,
                        f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
                    )
                return result

            bar.update = update

        tqdm_cls.__init__ = patched_init
        return original_init

    @staticmethod
    def _print_download_hints(error: Exception, model_info: Dict) -> None:
        """Print remediation steps for common Hugging Face access errors.

        Matches on the error text: a 403 mentioning gated repositories, or a
        401 / "restricted" message.
        """
        message = str(error)
        if "403" in message and "gated repositories" in message.lower():
            print("Token permission issue detected!")
            print("Your Hugging Face token needs 'Read access to public gated repositories'")
            print("Please:")
            print("1. Go to https://huggingface.co/settings/tokens")
            print("2. Edit your token or create a new one")
            print("3. Enable 'Read access to public gated repositories'")
            print("4. Try the download again")
        elif "401" in message or "restricted" in message.lower():
            print("This model requires Hugging Face authentication.")
            print("Please visit the model page and accept terms first:")
            print(f"https://huggingface.co/{model_info['repo']}")

    # ------------------------------------------------------------------
    # Deletion / storage
    # ------------------------------------------------------------------

    def delete_model(self, model_id: str) -> bool:
        """Delete the local weight file, config file and directory of *model_id*.

        Ids that are not a single plain path component (e.g. '', '..',
        'a/b') are refused, since the directory removal would otherwise
        reach models_dir itself or paths outside it.

        Returns:
            True if anything was deleted, False otherwise.
        """
        if not self._is_safe_model_id(model_id):
            print(f"Refusing to delete invalid model id: {model_id!r}")
            return False

        model_file = self.models_dir / self._model_filename(model_id)
        config_file = self.models_dir / self._CONFIG_FILENAMES.get(
            model_id, f"{model_id}-config.json")

        deleted_something = False
        for file_path in (model_file, config_file):
            if not file_path.exists():
                continue
            try:
                file_path.unlink()
            except OSError as e:
                print(f"Error deleting {file_path.name}: {e}")
            else:
                print(f"Deleted {file_path.name}")
                deleted_something = True

        model_path = self.models_dir / model_id
        if model_path.is_dir():
            try:
                shutil.rmtree(model_path)
            except OSError as e:
                print(f"Error deleting {model_id} directory: {e}")
            else:
                print(f"Deleted {model_id} directory")
                deleted_something = True

        if deleted_something:
            print(f"Deleted {model_id}")
            return True
        print(f"No files found for {model_id}")
        return False

    def get_download_progress(self, model_id: str) -> Dict:
        """Return a minimal status snapshot for *model_id*.

        Only reports whether the model is on disk and its catalogue size
        ('Unknown' for ids outside the catalogue); no byte-level progress is
        tracked here.
        """
        return {
            'model_id': model_id,
            'downloaded': self.is_model_downloaded(model_id),
            'size': self.available_models.get(model_id, {}).get('size', 'Unknown')
        }

    def get_storage_info(self) -> Dict:
        """Summarise disk usage of models_dir.

        ``total_size_bytes`` counts every file below models_dir (not only
        model weights); ``model_count`` counts downloaded catalogue models.
        """
        total_size = 0
        model_count = 0
        if self.models_dir.exists():
            model_count = sum(
                1 for model_id in self.available_models if self.is_model_downloaded(model_id))
            total_size = sum(
                p.stat().st_size for p in self.models_dir.rglob("*") if p.is_file())

        return {
            'total_size_bytes': total_size,
            'total_size_human': self._bytes_to_human(total_size),
            'model_count': model_count,
            'models_dir': str(self.models_dir)
        }

    def _bytes_to_human(self, bytes_value: int) -> str:
        """Format a byte count with one decimal using 1024-based units up to TB."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"
|
|