# fragmenta/app/core/model_manager.py
# Uploaded by MazCodes using huggingface_hub (commit 63f0b06, verified).
import os
import json
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Callable
from datetime import datetime
import requests
from huggingface_hub import snapshot_download, hf_hub_download
import hashlib
class ModelManager:
    """Catalog, download and delete the pretrained checkpoints in ``models_dir``.

    The catalog (``available_models``) is hard-coded in ``__init__``.  Weights
    are fetched from the Hugging Face Hub into the hub cache and then copied
    into ``models_dir`` under a flat ``<model>-<file>`` naming scheme.  Licence
    acceptance is recorded per model in a small JSON file (``terms_file``).

    All user feedback goes to stdout via ``print`` and, when supplied, to a
    ``progress_callback(percent, message)`` callable.
    """

    # Local weight-file names for the catalog entries whose presence is
    # decided purely by that single file.  Any other id falls back to
    # ``<model_id>-model.safetensors``.
    _KNOWN_WEIGHT_FILES = {
        'stable-audio-open-small': 'stable-audio-open-small-model.safetensors',
        'stable-audio-open-1.0': 'stable-audio-open-model.safetensors',
    }

    def __init__(self, config):
        """Create the manager and make sure its directories exist.

        Args:
            config: Object exposing ``get_path(name)``.  The value returned
                for ``"models_pretrained"`` is used like a ``pathlib.Path``
                (``mkdir``, ``/``, ``glob``, ``rglob``).
        """
        self.config = config
        self.models_dir = config.get_path("models_pretrained")
        self.models_dir.mkdir(exist_ok=True, parents=True)
        # NOTE(review): nothing in this class reads the 'checksum' entries and
        # the values look like placeholders -- downloads are not hash-verified.
        self.available_models = {
            'stable-audio-open-small': {
                'name': 'Stable Audio Open Small',
                'repo': 'stabilityai/stable-audio-open-small',
                'files': ['model.safetensors'],
                'size': '2.1 GB',
                'description': 'Fast generation, good quality, lower memory usage',
                'best_for': 'Beginners, quick experiments, limited GPU',
                'license': 'Stability AI License',
                'checksum': 'sha256:abc123...'
            },
            'stable-audio-open-1.0': {
                'name': 'Stable Audio Open 1.0',
                'repo': 'stabilityai/stable-audio-open-1.0',
                'files': ['model.safetensors'],
                'size': '8.2 GB',
                'description': 'Highest quality, more detailed audio',
                'best_for': 'Professional use, high-end GPUs',
                'license': 'Stability AI License',
                'checksum': 'sha256:def456...'
            }
        }
        # Relative path: resolved against the process working directory.
        self.terms_file = Path("config/terms_accepted.json")
        self.terms_file.parent.mkdir(exist_ok=True, parents=True)

    # ------------------------------------------------------------------
    # Naming helpers
    # ------------------------------------------------------------------
    def _weights_filename(self, model_id: str) -> str:
        """Return the local file name used for ``model_id``'s main weights."""
        return self._KNOWN_WEIGHT_FILES.get(
            model_id, f"{model_id}-model.safetensors")

    def _weights_path(self, model_id: str) -> Path:
        """Return the full local path of ``model_id``'s main weights file."""
        return self.models_dir / self._weights_filename(model_id)

    def _local_filename(self, model_id: str, remote_name: str) -> str:
        """Map a file name inside the hub repo to its name in ``models_dir``."""
        if remote_name == 'model.safetensors':
            return self._weights_filename(model_id)
        return f"{model_id}-{remote_name}"

    @staticmethod
    def _is_plain_model_id(model_id) -> bool:
        """Return True when ``model_id`` is a single, ordinary path component.

        Guards the delete path: ``""``, ``"."``, ``".."`` or anything holding
        a path separator would make ``models_dir / model_id`` point at
        ``models_dir`` itself or somewhere outside it.
        """
        if not isinstance(model_id, str) or model_id in ('', '.', '..'):
            return False
        if any(token in model_id for token in ('/', '\\', '\x00')):
            return False
        return Path(model_id).name == model_id

    # ------------------------------------------------------------------
    # Catalog queries
    # ------------------------------------------------------------------
    def get_available_models(self) -> List[Dict]:
        """Return one summary dict per catalog entry.

        Each dict carries the catalog's display fields plus ``id``,
        ``downloaded``, ``downloaded_size`` (human-readable string, or
        ``None`` when the model is not installed) and ``terms_accepted``.
        """
        models = []
        for model_id, info in self.available_models.items():
            is_downloaded = self.is_model_downloaded(model_id)
            downloaded_size = (
                self._installed_size(model_id) if is_downloaded else None)
            models.append({
                'id': model_id,
                'name': info['name'],
                'size': info['size'],
                'description': info['description'],
                'best_for': info['best_for'],
                'license': info['license'],
                'downloaded': is_downloaded,
                'downloaded_size': downloaded_size,
                'terms_accepted': self.is_terms_accepted(model_id)
            })
        return models

    def _installed_size(self, model_id: str) -> Optional[str]:
        """Return the on-disk size of an installed model, or ``None``."""
        weights = self._weights_path(model_id)
        if model_id in self._KNOWN_WEIGHT_FILES:
            return self._get_file_size(weights) if weights.exists() else None
        model_path = self.models_dir / model_id
        if model_path.exists():
            return self._get_downloaded_size(model_path)
        # download_model() stores other ids as a flat weights file, so report
        # that file's size instead of giving up when no directory exists.
        if weights.is_file():
            return self._get_file_size(weights)
        return None

    def _get_file_size(self, file_path: Path) -> str:
        """Return the human-readable size of one file (``"0 B"`` if absent)."""
        if not file_path.is_file():
            return "0 B"
        return self._bytes_to_human(file_path.stat().st_size)

    @staticmethod
    def _directory_bytes(directory: Path) -> int:
        """Sum the sizes of all regular files below ``directory``."""
        return sum(
            entry.stat().st_size
            for entry in directory.rglob("*")
            if entry.is_file()
        )

    def _get_downloaded_size(self, model_path: Path) -> str:
        """Return the human-readable total size of a model directory.

        Returns ``"0 B"`` when ``model_path`` does not exist.
        """
        if not model_path.exists():
            return "0 B"
        return self._bytes_to_human(self._directory_bytes(model_path))

    def get_model_info(self, model_id: str) -> Optional[Dict]:
        """Return a copy of the catalog entry plus status flags.

        Returns ``None`` for ids that are not in the catalog.  The returned
        dict (including its ``files`` list) can be mutated freely without
        touching the catalog.
        """
        if model_id not in self.available_models:
            return None
        info = dict(self.available_models[model_id])
        info['files'] = list(info['files'])
        info['id'] = model_id
        info['downloaded'] = self.is_model_downloaded(model_id)
        info['terms_accepted'] = self.is_terms_accepted(model_id)
        return info

    def is_model_downloaded(self, model_id: str) -> bool:
        """Report whether ``model_id`` appears to be installed locally.

        For the ids in ``_KNOWN_WEIGHT_FILES`` only the single weights file
        counts.  For any other id a non-empty ``models_dir/<model_id>``
        directory, or any ``*<model_id>*.safetensors`` file directly inside
        ``models_dir``, counts.
        """
        if model_id in self._KNOWN_WEIGHT_FILES:
            return self._weights_path(model_id).is_file()
        model_path = self.models_dir / model_id
        if model_path.is_dir():
            return any(model_path.iterdir())
        return any(self.models_dir.glob(f"*{model_id}*.safetensors"))

    # ------------------------------------------------------------------
    # Licence-terms bookkeeping
    # ------------------------------------------------------------------
    def _load_terms(self) -> Dict:
        """Load the terms file, returning ``{}`` whenever it is unusable.

        "Unusable" covers a missing or unreadable file, invalid JSON, and
        valid JSON whose top level is not an object.
        """
        try:
            with open(self.terms_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and decoding errors.
            return {}
        return data if isinstance(data, dict) else {}

    def is_terms_accepted(self, model_id: str) -> bool:
        """Return True when an acceptance record exists for ``model_id``."""
        entry = self._load_terms().get(model_id)
        if not isinstance(entry, dict):
            return False
        return bool(entry.get('accepted', False))

    def accept_terms(self, model_id: str) -> bool:
        """Record licence acceptance for ``model_id`` in the terms file.

        Returns:
            True when the record was saved; False for ids outside the
            catalog or when the file could not be written.
        """
        if model_id not in self.available_models:
            return False
        terms_data = self._load_terms()
        terms_data[model_id] = {
            'accepted': True,
            'accepted_at': datetime.now().isoformat(),
            'model_name': self.available_models[model_id]['name'],
            'license': self.available_models[model_id]['license']
        }
        try:
            # Serialise before opening the file so that a serialisation
            # failure cannot leave a truncated terms file behind.
            payload = json.dumps(terms_data, indent=2)
            with open(self.terms_file, 'w', encoding='utf-8') as f:
                f.write(payload)
            return True
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving terms acceptance: {e}")
            return False

    # ------------------------------------------------------------------
    # Download
    # ------------------------------------------------------------------
    def download_model(self, model_id: str, progress_callback: Optional[Callable] = None) -> bool:
        """Download ``model_id`` from the Hugging Face Hub into ``models_dir``.

        Args:
            model_id: Key of an entry in ``available_models``.
            progress_callback: Optional ``callback(percent, message)``.
                Percentages run 0-100; 0 is also used to signal a failure.

        Returns:
            True when the expected weights file exists after the run, False
            otherwise (unknown id, authentication failure, download error).
        """
        if model_id not in self.available_models:
            return False

        def report(percent, message):
            if progress_callback:
                progress_callback(percent, message)

        if not self.is_terms_accepted(model_id):
            # NOTE(review): terms are accepted on the user's behalf here
            # rather than refusing the download -- confirm this is intended.
            print(f"Terms not accepted for {model_id}")
            if self.accept_terms(model_id):
                print(f"Automatically accepted terms for {model_id}")
            else:
                print(f"Could not record terms acceptance for {model_id}")

        model_info = self.available_models[model_id]
        target_dir = self.models_dir
        target_dir.mkdir(exist_ok=True, parents=True)

        try:
            print(f"Downloading {model_info['name']} to {target_dir}")
            report(0, f"Starting download of {model_info['name']}...")

            if not self._ensure_hf_authenticated(model_info['repo'], report):
                return False

            report(20, "Starting file download...")
            try:
                downloaded_files = []
                remote_files = model_info['files']
                total_files = len(remote_files)
                for index, remote_name in enumerate(remote_files):
                    report(
                        20 + int((index / total_files) * 70),
                        f"Starting download of {remote_name}..."
                    )
                    try:
                        final_filename = self._local_filename(
                            model_id, remote_name)
                        cached_file = self._hub_download_with_progress(
                            model_info['repo'], remote_name,
                            index, total_files, progress_callback)
                        final_path = target_dir / final_filename
                        self._install_file(Path(cached_file), final_path)
                        print(f"Saved as {final_filename}")
                        downloaded_files.append(str(final_path))
                        report(
                            20 + int(((index + 1) / total_files) * 70),
                            f"Completed {remote_name}"
                        )
                    except Exception as file_error:
                        # Best effort: keep going with the remaining files;
                        # the existence check below decides the final result.
                        print(
                            f"Failed to download {remote_name}: {file_error}")
                        self._print_hub_error_hints(
                            file_error, model_info['repo'])
                        report(0, f"Failed to download {remote_name}")
                print(f"Downloaded {len(downloaded_files)} files")
                report(95, "Download completed, verifying files...")
            except Exception as download_error:
                print(f"Error during download: {download_error}")
                report(0, f"Download failed: {str(download_error)}")
                return False

            report(95, "Verifying download...")
            expected_files = [self._weights_filename(model_id)]
            files_exist = any(
                (target_dir / expected_file).exists()
                for expected_file in expected_files
            )
            if files_exist:
                report(100, "Download complete!")
                print(f"Successfully downloaded {model_info['name']}")
                return True
            report(0, "Download verification failed")
            print(f"Expected files not found: {expected_files}")
            return False
        except Exception as e:
            # Top-level boundary: any unexpected failure becomes a False
            # result plus a console / callback message.
            print(f"Error downloading {model_info['name']}: {e}")
            report(0, f"Error: {str(e)}")
            self._print_hub_error_hints(e, model_info['repo'])
            return False

    @staticmethod
    def _hf_user_label(user) -> str:
        """Return just the account name from a ``whoami()`` payload.

        The full payload is kept out of the logs because it can carry
        account details beyond the name.
        """
        if isinstance(user, dict) and user.get('name'):
            return str(user['name'])
        return "<unknown user>"

    def _ensure_hf_authenticated(self, repo_id: str, report: Callable) -> bool:
        """Verify Hugging Face credentials, prompting for them if possible.

        Order of attempts: existing credentials; then, unless
        ``FRAGMENTA_DOCKER=1``, the project's authentication dialog; if the
        dialog cannot be imported, manual instructions are printed.

        Returns:
            True once ``HfApi().whoami()`` succeeds, False otherwise.
        """
        from huggingface_hub import HfApi
        api = HfApi()
        try:
            user = api.whoami()
        except Exception as auth_error:
            print(f"Not authenticated with Hugging Face: {auth_error}")
            report(0, "Authentication required...")
        else:
            print(f"Authenticated as: {self._hf_user_label(user)}")
            report(10, "Authentication verified...")
            return True

        if os.environ.get('FRAGMENTA_DOCKER', '').strip() == '1':
            # No interactive dialog in Docker mode; the token must be
            # supplied through the UI or the environment.
            print("Docker mode: HF authentication required. "
                  "Set your token via Model Setup in the browser UI, "
                  "or pass -e HF_TOKEN=hf_xxx to docker run.")
            report(0, "HF authentication required — use Model Setup to set your token")
            return False

        try:
            # The call stays inside the try so that an ImportError raised
            # while the dialog runs also falls back to the instructions.
            from app.core.hf_auth_dialog import show_hf_auth_dialog
            dialog_ok = show_hf_auth_dialog()
        except ImportError:
            print("To download models, you need to:")
            print(f"1. Visit https://huggingface.co/{repo_id}")
            print("2. Accept the terms and conditions")
            print("3. Log in to your Hugging Face account")
            print(
                "4. Get your access token from https://huggingface.co/settings/tokens")
            print("5. Run: huggingface-cli login")
            report(0, "Manual authentication required")
            return False

        if not dialog_ok:
            print("Authentication dialog was cancelled")
            report(0, "Authentication cancelled")
            return False

        try:
            user = api.whoami()
        except Exception as retry_error:
            print(f"Still not authenticated: {retry_error}")
            report(0, "Authentication failed")
            return False
        print(f"Now authenticated as: {self._hf_user_label(user)}")
        report(10, "Authentication successful...")
        return True

    @staticmethod
    def _hub_download_with_progress(repo_id: str, filename: str, file_index: int,
                                    total_files: int,
                                    progress_callback: Optional[Callable]) -> str:
        """Fetch one file via ``hf_hub_download`` and forward tqdm progress.

        ``tqdm.__init__`` is replaced for the duration of the call so that
        each new bar's ``update`` also invokes ``progress_callback``, mapped
        into the 20-90% band shared by all files.  The replacement is
        process-wide: any tqdm bar created while the download runs is
        hooked too.  The original ``__init__`` is always restored.

        Returns:
            The path returned by ``hf_hub_download``.
        """
        from huggingface_hub import hf_hub_download
        from tqdm import tqdm

        original_tqdm_init = tqdm.__init__

        def patched_tqdm_init(bar, *args, **kwargs):
            original_tqdm_init(bar, *args, **kwargs)
            original_update = bar.update

            def new_update(n=1):
                result = original_update(n)
                if progress_callback and bar.total:
                    file_progress = bar.n / bar.total
                    overall_progress = (file_index + file_progress) / total_files
                    percent = 20 + int(overall_progress * 70)
                    downloaded_mb = bar.n / (1024 * 1024)
                    total_mb = bar.total / (1024 * 1024)
                    progress_callback(
                        percent,
                        f"Downloading: {downloaded_mb:.1f}MB / {total_mb:.1f}MB"
                    )
                return result

            bar.update = new_update

        tqdm.__init__ = patched_tqdm_init
        try:
            # NOTE(review): resume_download is kept as in the original call;
            # check it against the installed huggingface_hub version.
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                resume_download=True
            )
        finally:
            tqdm.__init__ = original_tqdm_init

    @staticmethod
    def _install_file(source: Path, destination: Path) -> None:
        """Copy ``source`` to ``destination`` without exposing partial files.

        The data is copied to ``<destination>.part`` first and renamed into
        place, so an interrupted copy never leaves a truncated file that
        ``is_model_downloaded`` would accept as an installed model.
        """
        destination.parent.mkdir(parents=True, exist_ok=True)
        partial = destination.with_name(destination.name + '.part')
        try:
            shutil.copy2(str(source), str(partial))
            os.replace(str(partial), str(destination))
        except BaseException:
            try:
                partial.unlink()
            except OSError:
                pass
            raise

    @staticmethod
    def _print_hub_error_hints(error: Exception, repo_id: str) -> None:
        """Print remediation steps for recognisable hub permission errors."""
        message = str(error)
        lowered = message.lower()
        if "403" in message and "gated repositories" in lowered:
            print("Token permission issue detected!")
            print(
                "Your Hugging Face token needs 'Read access to public gated repositories'")
            print("Please:")
            print("1. Go to https://huggingface.co/settings/tokens")
            print("2. Edit your token or create a new one")
            print("3. Enable 'Read access to public gated repositories'")
            print("4. Try the download again")
        elif "401" in message or "restricted" in lowered:
            print("This model requires Hugging Face authentication.")
            print("Please visit the model page and accept terms first:")
            print(f"https://huggingface.co/{repo_id}")

    # ------------------------------------------------------------------
    # Deletion and storage reporting
    # ------------------------------------------------------------------
    def delete_model(self, model_id: str) -> bool:
        """Remove the local files belonging to ``model_id``.

        Deletes the weights file, ``<model_id>-config.json`` and the
        ``models_dir/<model_id>`` directory, whichever exist.

        Returns:
            True when at least one file or directory was removed.  False
            when nothing was found, or when ``model_id`` is not a plain
            single path component (such ids are refused outright).
        """
        if not self._is_plain_model_id(model_id):
            print(f"Refusing to delete invalid model id: {model_id!r}")
            return False

        deleted_something = False
        candidates = (
            self._weights_path(model_id),
            self.models_dir / f"{model_id}-config.json",
        )
        for file_path in candidates:
            if not file_path.exists():
                continue
            try:
                file_path.unlink()
                print(f"Deleted {file_path.name}")
                deleted_something = True
            except OSError as e:
                print(f"Error deleting {file_path.name}: {e}")

        model_path = self.models_dir / model_id
        if model_path.is_dir():
            try:
                shutil.rmtree(model_path)
                print(f"Deleted {model_id} directory")
                deleted_something = True
            except OSError as e:
                print(f"Error deleting {model_id} directory: {e}")

        if deleted_something:
            print(f"Deleted {model_id}")
            return True
        print(f"No files found for {model_id}")
        return False

    def get_download_progress(self, model_id: str) -> Dict:
        """Return a status snapshot: id, installed flag and catalog size.

        ``size`` is the catalog's nominal size string, or ``'Unknown'`` for
        ids outside the catalog.  No live transfer progress is tracked here.
        """
        return {
            'model_id': model_id,
            'downloaded': self.is_model_downloaded(model_id),
            'size': self.available_models.get(model_id, {}).get('size', 'Unknown')
        }

    def get_storage_info(self) -> Dict:
        """Summarise disk usage of ``models_dir``.

        ``total_size_bytes`` counts every regular file below ``models_dir``,
        not only catalog models; ``model_count`` counts installed catalog
        entries.
        """
        total_size = 0
        model_count = 0
        if self.models_dir.exists():
            model_count = sum(
                1 for model_id in self.available_models
                if self.is_model_downloaded(model_id)
            )
            total_size = self._directory_bytes(self.models_dir)
        return {
            'total_size_bytes': total_size,
            'total_size_human': self._bytes_to_human(total_size),
            'model_count': model_count,
            'models_dir': str(self.models_dir)
        }

    def _bytes_to_human(self, bytes_value: int) -> str:
        """Format a byte count with one decimal and a 1024-based unit."""
        value = float(bytes_value)
        for unit in ['B', 'KB', 'MB', 'GB']:
            if value < 1024.0:
                return f"{value:.1f} {unit}"
            value /= 1024.0
        return f"{value:.1f} TB"