|
|
|
|
|
"""
LoraMux for the ContinuumAgent project.

Selects LoRA patches from a registry index and validates their directories
so they can be loaded and applied at inference time.
"""
|
|
|
|
|
|
import os
|
|
|
import json
|
|
|
import glob
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
import datetime
|
|
|
|
|
|
class LoraMux:
    """Select and validate LoRA patches listed in a registry index.

    The registry directory is expected to contain an ``index.json`` file of
    the form ``{"patches": [{"path": ..., "created_at": ...}, ...]}``. Each
    entry's ``path`` is resolved relative to the registry directory. The
    class never reorders or mutates the loaded registry.
    """

    def __init__(self, registry_dir: str = "models/registry"):
        """
        Initialize the LoraMux and read the registry index.

        Args:
            registry_dir: Directory containing the LoRA patch registry
                (``index.json`` plus the per-patch directories it names).
        """
        self.registry_dir = registry_dir
        self.index_path = os.path.join(registry_dir, "index.json")
        # Not read or written by any method of this class; kept as a public
        # attribute for existing callers.
        self.active_patches: Dict[str, Any] = {}
        self.registry_index: Dict[str, Any] = {"patches": []}
        self._load_registry_index()

    def _load_registry_index(self) -> None:
        """Load ``index.json`` into ``self.registry_index``.

        A missing, unreadable, or malformed index (anything other than a JSON
        object whose optional ``patches`` entry is a list) is reported and
        replaced by an empty registry, so the mux stays usable.
        """
        if not os.path.exists(self.index_path):
            print(f"Registry index not found at {self.index_path}")
            self.registry_index = {"patches": []}
            return

        try:
            with open(self.index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading registry index: {e}")
            self.registry_index = {"patches": []}
            return

        if not isinstance(index, dict) or not isinstance(index.get("patches", []), list):
            print(f"Error loading registry index: unexpected structure in {self.index_path}")
            self.registry_index = {"patches": []}
            return

        self.registry_index = index
        print(f"Loaded registry with {len(self._patches())} patches")

    def _patches(self) -> List[Dict[str, Any]]:
        """Return the registry's patch list (the live list, not a copy)."""
        return self.registry_index.get("patches", [])

    def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get list of available patches, optionally filtered by date.

        Args:
            date: Date string in YYYYMMDD format (optional). When given, only
                entries whose ``path`` starts with this string are returned.

        Returns:
            A new list of patch info dictionaries in registry order. Mutating
            the returned list does not affect the registry.
        """
        patches = self._patches()
        if date:
            # Filtering relies on patch paths being prefixed with the date;
            # a missing or null path never matches.
            return [p for p in patches if (p.get("path") or "").startswith(date)]
        return list(patches)

    def get_latest_patch(self) -> Optional[Dict[str, Any]]:
        """
        Get the latest available patch.

        Entries are ranked by their ``created_at`` value (missing or null
        counts as ``""``). Among ties, the entry that appears first in the
        registry wins.

        Returns:
            Patch info dictionary, or None if the registry has no patches.
        """
        patches = self._patches()
        if not patches:
            return None
        # max() does not reorder the registry (unlike list.sort()) and returns
        # the first maximal entry on ties.
        return max(patches, key=lambda p: p.get("created_at") or "")

    def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get patches for a specific date, falling back to the latest patch.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today,
                local time).

        Returns:
            Patches matching the date. If none match, a one-element list
            holding the most recent patch. If the registry is empty, [].
        """
        if not date_str:
            date_str = datetime.datetime.now().strftime("%Y%m%d")

        patches = self.get_available_patches(date_str)
        if patches:
            return patches

        print(f"No patches found for date {date_str}")
        latest = self.get_latest_patch()
        if latest is None:
            return []
        print(f"Using most recent patch instead: {latest.get('path')}")
        return [latest]

    def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
        """
        Get full path to patch directory.

        Args:
            patch_info: Patch info dictionary.

        Returns:
            ``registry_dir`` joined with the entry's ``path`` (or with ``""``
            when the entry has no path).
        """
        return os.path.join(self.registry_dir, patch_info.get("path", ""))

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Select patches for a date and return the directories that are usable.

        This method does not load weights itself. It resolves the selected
        patches to directories and keeps those that contain an
        ``adapter_model.bin`` file.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today).

        Returns:
            List of validated patch directory paths.
        """
        patch_paths = [self.get_patch_path(p) for p in self.get_patches_for_date(date_str)]
        print(f"Selected {len(patch_paths)} patches for loading")

        valid_paths: List[str] = []
        for path in patch_paths:
            # isdir() is False for nonexistent paths, so no separate exists() check.
            if not os.path.isdir(path):
                print(f"Warning: Invalid patch path {path}")
                continue
            if os.path.isfile(os.path.join(path, "adapter_model.bin")):
                valid_paths.append(path)
            else:
                print(f"Warning: Missing adapter_model.bin in {path}")
        return valid_paths
|
|
|
|
|
|
def main():
    """Exercise LoraMux against the default registry and print a summary."""
    mux = LoraMux()

    available = mux.get_available_patches()
    print(f"Available patches: {len(available)}")

    newest = mux.get_latest_patch()
    if newest:
        print(f"Latest patch: {newest.get('path')}")

    # Use today's local date (YYYYMMDD), the same format the mux filters on.
    today = datetime.datetime.now().strftime("%Y%m%d")
    loaded = mux.load_patches(today)
    print(f"Loaded patches for {today}: {len(loaded)}")

    for idx, patch_dir in enumerate(loaded, start=1):
        print(f" [{idx}] {patch_dir}")


if __name__ == "__main__":
    main()