#!/usr/bin/env python
"""
LoraMux for ContinuumAgent Project
Dynamically loads and applies LoRA patches at inference time

Within this module, patches are selected from the registry index and their
directories validated; applying the adapters is left to the caller.
"""

import os
import json
import glob
from typing import List, Dict, Any, Optional
import datetime


class LoraMux:
    """Manages and applies LoRA patches dynamically.

    Reads ``<registry_dir>/index.json`` and selects patch directories from it.
    The methods here return patch metadata or validated directory paths; they
    do not load adapter weights themselves.
    """

    def __init__(self, registry_dir: str = "models/registry"):
        """
        Initialize the LoraMux

        Args:
            registry_dir: Directory containing LoRA patch registry
                (its ``index.json`` is read immediately)
        """
        self.registry_dir = registry_dir
        self.index_path = os.path.join(registry_dir, "index.json")
        # Not read or written by any method in this module; presumably
        # reserved for callers that track applied patches -- verify.
        self.active_patches = {}

        # Load registry index (falls back to an empty registry on failure)
        self._load_registry_index()

    def _load_registry_index(self) -> None:
        """
        Load the registry index file into ``self.registry_index``.

        If the index is missing, unreadable, not valid JSON, or not a JSON
        object whose ``"patches"`` entry is a list, a message is printed and
        the registry falls back to ``{"patches": []}``. Construction therefore
        never fails because of the index file.
        """
        if not os.path.exists(self.index_path):
            print(f"Registry index not found at {self.index_path}")
            self.registry_index = {"patches": []}
            return

        try:
            with open(self.index_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading registry index: {e}")
            self.registry_index = {"patches": []}
            return

        # Every other method calls .get("patches", []) and treats the result
        # as a list of dicts, so reject any other top-level shape up front.
        if not isinstance(data, dict) or not isinstance(data.get("patches", []), list):
            print(
                "Error loading registry index: expected a JSON object with a "
                f"'patches' list in {self.index_path}"
            )
            self.registry_index = {"patches": []}
            return

        self.registry_index = data
        print(f"Loaded registry with {len(data.get('patches', []))} patches")

    @staticmethod
    def _created_at_key(patch: Dict[str, Any]) -> Any:
        """Return the sort key for *patch*: its ``created_at``, or ``""`` if missing/None."""
        # NOTE(review): assumes created_at values are mutually comparable
        # (e.g. all ISO-8601 strings, so string order is chronological) -- confirm
        # against whatever writes index.json.
        return patch.get("created_at") or ""

    def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get list of available patches, optionally filtered by date

        Args:
            date: Date string in YYYYMMDD format (optional). When given, only
                patches whose ``path`` starts with this string are returned.

        Returns:
            A new list of patch info dictionaries, in registry order. Reordering
            the returned list does not affect the registry.
        """
        patches = self.registry_index.get("patches", [])
        if date:
            # Patch paths are matched on a leading date prefix
            return [p for p in patches if (p.get("path") or "").startswith(date)]
        return list(patches)

    def get_latest_patch(self) -> Optional[Dict[str, Any]]:
        """
        Get the latest available patch

        The registry order is left untouched. On equal ``created_at`` values,
        the earliest such entry in registry order wins.

        Returns:
            Patch info dictionary with the greatest ``created_at``, or None if
            no patches are available
        """
        patches = self.registry_index.get("patches", [])
        if not patches:
            return None
        return max(patches, key=self._created_at_key)

    def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get patches for a specific date

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today, local time)

        Returns:
            List of patch info dictionaries matching the date. If none match,
            the result is a one-element list holding the newest patch in the
            registry. It is empty if the registry itself is empty.
        """
        if not date_str:
            # Default to today's date
            date_str = datetime.datetime.now().strftime("%Y%m%d")

        patches = self.get_available_patches(date_str)
        if patches:
            return patches

        print(f"No patches found for date {date_str}")
        # Fall back to the newest patch overall. NOTE: this is not limited to
        # patches created before date_str.
        latest = self.get_latest_patch()
        if latest is None:
            return []
        print(f"Using most recent patch instead: {latest.get('path')}")
        return [latest]

    def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
        """
        Get full path to patch directory

        Args:
            patch_info: Patch info dictionary (its ``path`` is relative to the
                registry directory)

        Returns:
            Path to patch directory
        """
        return os.path.join(self.registry_dir, patch_info.get("path") or "")

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load patches for a specific date

        Selects patches via ``get_patches_for_date`` and keeps only the
        directories that contain an ``adapter_model.bin`` file. A warning is
        printed for each rejected candidate.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of validated patch directory paths
        """
        patches = self.get_patches_for_date(date_str)
        patch_paths = [self.get_patch_path(p) for p in patches]
        print(f"Selected {len(patch_paths)} patches for loading")

        valid_paths = []
        for path in patch_paths:
            if not os.path.isdir(path):
                print(f"Warning: Invalid patch path {path}")
                continue
            adapter_model_path = os.path.join(path, "adapter_model.bin")
            if os.path.isfile(adapter_model_path):
                valid_paths.append(path)
            else:
                print(f"Warning: Missing adapter_model.bin in {path}")

        return valid_paths


def main():
    """Test LoraMux functionality"""
    mux = LoraMux()

    # List available patches
    patches = mux.get_available_patches()
    print(f"Available patches: {len(patches)}")

    # Get latest patch
    latest = mux.get_latest_patch()
    if latest:
        print(f"Latest patch: {latest.get('path')}")

    # Load patches for today
    today = datetime.datetime.now().strftime("%Y%m%d")
    loaded = mux.load_patches(today)
    print(f"Loaded patches for {today}: {len(loaded)}")

    # List patch paths
    for i, path in enumerate(loaded):
        print(f"  [{i+1}] {path}")


if __name__ == "__main__":
    main()