# Synesthesia / runtime / model_manager.py
# Author: Ashiedu
# Commit: Sync unified workbench
# Revision: 0490201 (verified)
import os
import json
import time
from huggingface_hub import snapshot_download
class ModelManager:
    """Manage local model directories and track which model is marked active.

    Base models live as sub-directories of ``base_path``; optimized artifacts
    live as files in ``quantized_path``. The currently "active" model is
    persisted as a small JSON document at ``state_path``.
    """

    # File extensions that verify_integrity() accepts as model artifacts.
    ARTIFACT_EXTENSIONS = (".onnx", ".vmfb", ".gguf")
    # An artifact must be strictly larger than this many bytes to count;
    # presumably this rejects tiny placeholder files — confirm intent.
    MIN_ARTIFACT_BYTES = 1 * 1024 * 1024

    def __init__(self, base_path="models", quantized_path="models/quantized",
                 state_path="runtime/logs/active_model.json"):
        """Create the manager and make sure its directories exist.

        Args:
            base_path: Directory holding one sub-directory per base model.
            quantized_path: Directory holding quantized model files.
            state_path: JSON file that records the active model.
        """
        self.base_path = base_path
        self.quantized_path = quantized_path
        self.state_path = state_path
        os.makedirs(self.base_path, exist_ok=True)
        os.makedirs(self.quantized_path, exist_ok=True)
        # A bare filename has no parent directory, and os.makedirs("") raises.
        state_dir = os.path.dirname(self.state_path)
        if state_dir:
            os.makedirs(state_dir, exist_ok=True)

    def list_models(self):
        """Return the sorted names of base-model directories in ``base_path``.

        The quantized directory is skipped when it sits inside ``base_path``
        (as it does with the defaults), as is any directory named
        ``quantized``. Returns an empty list if ``base_path`` cannot be read.
        """
        quantized_abs = os.path.abspath(self.quantized_path)
        try:
            entries = os.listdir(self.base_path)
        except OSError:
            return []
        return sorted(
            d for d in entries
            if os.path.isdir(os.path.join(self.base_path, d))
            and d != "quantized"
            and os.path.abspath(os.path.join(self.base_path, d)) != quantized_abs
        )

    def list_quantized(self):
        """Return the sorted names of regular files in ``quantized_path``.

        Returns an empty list if ``quantized_path`` cannot be read.
        """
        try:
            entries = os.listdir(self.quantized_path)
        except OSError:
            return []
        return sorted(
            f for f in entries
            if os.path.isfile(os.path.join(self.quantized_path, f))
        )

    def download_model(self, repo_id):
        """Download a Hugging Face Hub repository into ``base_path``.

        Runs the ``hf download`` CLI. If the ``hf`` executable is not
        installed, falls back to ``huggingface_hub.snapshot_download``.

        Args:
            repo_id: Hub repository id, e.g. ``"org/model-name"``; the last
                path component becomes the local directory name.

        Returns:
            The local directory path on success, otherwise ``None``.
        """
        local_dir = os.path.join(self.base_path, repo_id.rstrip("/").split("/")[-1])
        print(f"Downloading {repo_id} to {local_dir} via hf CLI...")
        import subprocess

        # The hf CLI's download command is `hf download` (`hf hub download`
        # is not a valid subcommand).
        cmd = ["hf", "download", repo_id, "--local-dir", local_dir]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError:
            # `hf` is not on PATH: use the Python API imported at module level.
            return self._download_with_api(repo_id, local_dir)
        except Exception as e:
            print(f"Download failed: {e}")
            return None
        if result.returncode == 0:
            return local_dir
        print(f"hf CLI download failed: {result.stderr}")
        return None

    @staticmethod
    def _download_with_api(repo_id, local_dir):
        """Download ``repo_id`` into ``local_dir`` via ``snapshot_download``."""
        try:
            snapshot_download(repo_id=repo_id, local_dir=local_dir)
        except Exception as e:
            print(f"Download failed: {e}")
            return None
        return local_dir

    def verify_integrity(self, model_name, min_size_bytes=MIN_ARTIFACT_BYTES,
                         extensions=ARTIFACT_EXTENSIONS):
        """Return True if ``model_name`` has at least one real model artifact.

        ``model_name`` is looked up under ``base_path`` first and, if absent
        there, under ``quantized_path``. It may name a single file or a
        directory, which is searched recursively.

        Args:
            model_name: File or directory name relative to either model root.
            min_size_bytes: A file must be strictly larger than this to count.
            extensions: Accepted filename suffix(es).

        Returns:
            True if any matching file is large enough, else False (including
            when ``model_name`` does not exist).
        """
        path = os.path.join(self.base_path, model_name)
        if not os.path.exists(path):
            path = os.path.join(self.quantized_path, model_name)
            if not os.path.exists(path):
                return False
        if isinstance(extensions, str):
            extensions = (extensions,)
        extensions = tuple(extensions)
        if os.path.isfile(path):
            candidates = [path]
        else:
            candidates = (
                os.path.join(root, f)
                for root, _, files in os.walk(path)
                for f in files
            )
        return any(self._is_artifact(p, extensions, min_size_bytes) for p in candidates)

    @staticmethod
    def _is_artifact(file_path, extensions, min_size_bytes):
        """Return True if ``file_path`` has an accepted suffix and is large enough."""
        if not file_path.endswith(extensions):
            return False
        try:
            return os.path.getsize(file_path) > min_size_bytes
        except OSError:
            # e.g. a dangling symlink or a file removed mid-walk.
            return False

    def load_model(self, model_name):
        """Record ``model_name`` as the active model in ``state_path``.

        Only the state file is written; no weights are loaded here.
        ``loaded_at`` is stored as the string form of the Unix timestamp.

        Returns:
            Always ``True``.
        """
        # os.times().elapsed counts from an arbitrary past point (and is 0 on
        # Windows), so record wall-clock time instead.
        state = {"active_model": model_name, "loaded_at": str(time.time())}
        with open(self.state_path, "w", encoding="utf-8") as f:
            json.dump(state, f)
        print(f"Model {model_name} loaded.")
        return True

    def unload_model(self):
        """Clear the active-model state by deleting ``state_path`` if present.

        Returns:
            Always ``True``.
        """
        try:
            os.remove(self.state_path)
        except FileNotFoundError:
            pass  # nothing recorded as active; unloading is a no-op
        print("Model unloaded.")
        return True
if __name__ == "__main__":
    # Smoke check: print what is on disk under the default model roots.
    manager = ModelManager()
    listings = (
        ("Models", manager.list_models),
        ("Quantized", manager.list_quantized),
    )
    for label, lister in listings:
        print(f"{label}: {lister()}")