ContinuumAgent / runtime\lora_mux.py
deasdutta's picture
Upload runtime\lora_mux.py with huggingface_hub
dc39aaf verified
#!/usr/bin/env python
"""
LoraMux for ContinuumAgent Project
Dynamically loads and applies LoRA patches at inference time
"""
import os
import json
import glob
from typing import List, Dict, Any, Optional
import datetime
class LoraMux:
    """Manage a registry of LoRA patches and select which ones to load.

    The registry is a directory containing an ``index.json`` file. The index
    is expected to be a JSON object with a ``"patches"`` list; each entry is a
    dictionary whose ``"path"`` is a directory name relative to the registry
    directory and whose ``"created_at"`` value is used to order patches.
    """

    def __init__(self, registry_dir: str = "models/registry"):
        """
        Initialize the LoraMux

        Args:
            registry_dir: Directory containing LoRA patch registry
        """
        self.registry_dir = registry_dir
        self.index_path = os.path.join(registry_dir, "index.json")
        # Initialized empty; nothing in this class populates it.
        self.active_patches = {}
        # Load registry index
        self._load_registry_index()

    @staticmethod
    def _created_at_key(patch_info: Dict[str, Any]) -> Any:
        """Return the ordering key for a patch entry.

        A missing or null ``created_at`` is treated as the empty string so
        that it can be compared against string timestamps without raising.
        """
        return patch_info.get("created_at") or ""

    def _load_registry_index(self) -> None:
        """Load registry index file into ``self.registry_index``.

        Best-effort: a missing, unreadable, unparsable or wrongly shaped
        index is reported on stdout and replaced by an empty registry
        instead of raising.
        """
        if not os.path.exists(self.index_path):
            print(f"Registry index not found at {self.index_path}")
            self.registry_index = {"patches": []}
            return

        try:
            with open(self.index_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading registry index: {e}")
            self.registry_index = {"patches": []}
            return

        # Every other method calls .get("patches", []) on the index and
        # iterates the result, so insist on that shape up front.
        if not isinstance(loaded, dict) or not isinstance(
            loaded.get("patches", []), list
        ):
            print("Error loading registry index: unexpected index structure")
            self.registry_index = {"patches": []}
            return

        self.registry_index = loaded
        print(f"Loaded registry with {len(self.registry_index.get('patches', []))} patches")

    def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get list of available patches, optionally filtered by date

        Args:
            date: Date string in YYYYMMDD format (optional). A patch matches
                when its ``path`` starts with this string.

        Returns:
            List of patch info dictionaries, in registry order. The list is
            a new object, so callers may reorder it freely.
        """
        patches = self.registry_index.get("patches", [])
        if date:
            # Filter patches by date prefix; a null path never matches.
            return [p for p in patches if (p.get("path") or "").startswith(date)]
        # Copy so callers cannot reorder or truncate the registry itself.
        return list(patches)

    def get_latest_patch(self) -> Optional[Dict[str, Any]]:
        """
        Get the latest available patch

        Returns:
            Patch info dictionary with the greatest ``created_at`` value, or
            None if no patches available
        """
        patches = self.registry_index.get("patches", [])
        if not patches:
            return None
        # max() leaves the registry order untouched and, on ties, returns the
        # first maximal entry -- the same element a stable descending sort
        # would place at index 0.
        return max(patches, key=self._created_at_key)

    def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get patches for a specific date

        If no patch matches the date, fall back to the single most recent
        patch in the whole registry (by ``created_at``), regardless of
        whether it is older or newer than the requested date.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of patch info dictionaries (empty if the registry is empty)
        """
        if not date_str:
            # Default to today's date
            date_str = datetime.datetime.now().strftime("%Y%m%d")

        patches = self.get_available_patches(date_str)
        if patches:
            return patches

        print(f"No patches found for date {date_str}")
        latest = self.get_latest_patch()
        if latest is None:
            return patches
        print(f"Using most recent patch instead: {latest.get('path')}")
        return [latest]

    def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
        """
        Get full path to patch directory

        Args:
            patch_info: Patch info dictionary

        Returns:
            Registry directory joined with the entry's ``path`` (a missing or
            null ``path`` is treated as the empty string)
        """
        return os.path.join(self.registry_dir, patch_info.get("path") or "")

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Select and validate patch directories for a specific date

        Only paths are resolved and checked here; no adapter weights are
        read.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of patch directory paths that exist and contain an
            ``adapter_model.bin`` file
        """
        patches = self.get_patches_for_date(date_str)
        patch_paths = [self.get_patch_path(p) for p in patches]
        print(f"Selected {len(patch_paths)} patches for loading")

        valid_paths = []
        for path in patch_paths:
            if not os.path.isdir(path):
                print(f"Warning: Invalid patch path {path}")
                continue
            # NOTE(review): only the .bin adapter file name is recognized;
            # confirm whether other adapter formats should be accepted too.
            if os.path.isfile(os.path.join(path, "adapter_model.bin")):
                valid_paths.append(path)
            else:
                print(f"Warning: Missing adapter_model.bin in {path}")
        return valid_paths
def main():
    """Smoke-test LoraMux against the default registry and print a summary."""
    lora_mux = LoraMux()

    # How many patches does the registry list in total?
    available = lora_mux.get_available_patches()
    print(f"Available patches: {len(available)}")

    # Report the newest patch, if the registry has any.
    newest = lora_mux.get_latest_patch()
    if newest:
        print(f"Latest patch: {newest.get('path')}")

    # Resolve and validate the patch directories for today's date.
    today = datetime.datetime.now().strftime("%Y%m%d")
    loaded_paths = lora_mux.load_patches(today)
    print(f"Loaded patches for {today}: {len(loaded_paths)}")

    # Numbered listing, starting at 1.
    for position, patch_dir in enumerate(loaded_paths, start=1):
        print(f" [{position}] {patch_dir}")


if __name__ == "__main__":
    main()