# NOTE: the original first three lines were web file-viewer residue, not program
# text: "File size: 5,892 Bytes", the short id "dc39aaf", and a 1-180 line-number gutter.
#!/usr/bin/env python
"""
LoraMux for ContinuumAgent Project
Dynamically loads and applies LoRA patches at inference time
"""
import os
import json
import glob
from typing import List, Dict, Any, Optional
import datetime
class LoraMux:
    """Select LoRA patch directories from an on-disk registry.

    The registry is a directory containing an ``index.json`` file whose
    ``"patches"`` list holds one dictionary per patch. Of each dictionary this
    class reads ``"path"`` (patch directory, relative to the registry) and
    ``"created_at"`` (key used to rank patches by recency).

    The class only resolves and validates patch directories; it does not
    load any weights itself.
    """

    def __init__(self, registry_dir: str = "models/registry"):
        """
        Initialize the LoraMux and read the registry index.

        Args:
            registry_dir: Directory containing LoRA patch registry
        """
        self.registry_dir = registry_dir
        self.index_path = os.path.join(registry_dir, "index.json")
        # Not read or written by any other method of this class.
        self.active_patches = {}

        # Load registry index
        self._load_registry_index()

    def _load_registry_index(self) -> None:
        """Read ``index.json`` into ``self.registry_index``.

        Loading is best-effort: a missing, unreadable, malformed or
        wrongly-shaped index is reported on stdout and replaced by an empty
        registry (``{"patches": []}``) instead of raising.
        """
        # Start from the empty registry so every early return below leaves
        # the instance in a usable state.
        self.registry_index = {"patches": []}

        try:
            with open(self.index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
        except FileNotFoundError:
            print(f"Registry index not found at {self.index_path}")
            return
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading registry index: {e}")
            return

        # Every other method assumes a dict whose "patches" entry is a list;
        # reject anything else here rather than failing later.
        if not isinstance(index, dict) or not isinstance(
            index.get("patches", []), list
        ):
            print(
                "Error loading registry index: "
                "expected a JSON object with a 'patches' list"
            )
            return

        self.registry_index = index
        print(f"Loaded registry with {len(self.registry_index.get('patches', []))} patches")

    def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get list of available patches, optionally filtered by date

        Args:
            date: Date string in YYYYMMDD format (optional). A patch matches
                when its "path" starts with this string.

        Returns:
            New list of patch info dictionaries, in registry order
        """
        patches = self.registry_index.get("patches", [])
        if date:
            # `or ""` keeps an entry with a missing or null path from crashing
            # the filter; such an entry simply never matches.
            return [p for p in patches if (p.get("path") or "").startswith(date)]
        # Return a copy so callers cannot reorder or edit the registry list.
        return list(patches)

    def get_latest_patch(self) -> Optional[Dict[str, Any]]:
        """
        Get the latest available patch

        Returns:
            Patch info dictionary with the greatest "created_at" value (the
            earliest registry entry wins a tie), or None if no patches
            are available
        """
        patches = self.registry_index.get("patches", [])
        if not patches:
            return None
        # max() leaves the registry untouched, unlike an in-place sort, and
        # picks the same element a stable descending sort would put first.
        return max(patches, key=lambda p: p.get("created_at", ""))

    def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get patches for a specific date

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of patch info dictionaries for that date. When none match,
            a one-element list holding the most recent patch in the whole
            registry (regardless of its date), or an empty list when the
            registry is empty.
        """
        if not date_str:
            # Default to today's date (local time)
            date_str = datetime.datetime.now().strftime("%Y%m%d")

        patches = self.get_available_patches(date_str)
        if patches:
            return patches

        print(f"No patches found for date {date_str}")
        # Fall back to the newest patch overall; it is not restricted to
        # patches created before date_str.
        latest = self.get_latest_patch()
        if latest is None:
            return patches
        print(f"Using most recent patch instead: {latest.get('path')}")
        return [latest]

    def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
        """
        Get full path to patch directory

        Args:
            patch_info: Patch info dictionary

        Returns:
            The patch's "path" joined onto the registry directory
        """
        return os.path.join(self.registry_dir, patch_info.get("path", ""))

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Select and validate patch directories for a specific date

        Despite the name, nothing is loaded into memory: the method returns
        the directories that exist and contain an ``adapter_model.bin`` file,
        and prints a warning for each one that does not.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of validated patch directory paths
        """
        patches = self.get_patches_for_date(date_str)
        patch_paths = [self.get_patch_path(p) for p in patches]
        print(f"Selected {len(patch_paths)} patches for loading")

        valid_paths = []
        for path in patch_paths:
            # isdir() is already False for a path that does not exist.
            if not os.path.isdir(path):
                print(f"Warning: Invalid patch path {path}")
                continue
            # NOTE(review): recent PEFT releases write
            # adapter_model.safetensors by default; such directories are
            # rejected here - confirm whether that is intended.
            if not os.path.exists(os.path.join(path, "adapter_model.bin")):
                print(f"Warning: Missing adapter_model.bin in {path}")
                continue
            valid_paths.append(path)
        return valid_paths
def main():
    """Exercise LoraMux against the default registry and print the results.

    Prints the number of registered patches, the path of the most recent
    patch (when there is one), and the validated patch directories selected
    for today's date.
    """
    selector = LoraMux()

    # How many patches does the registry know about?
    available_count = len(selector.get_available_patches())
    print(f"Available patches: {available_count}")

    # Most recent patch, if the registry is not empty
    newest = selector.get_latest_patch()
    if newest:
        newest_path = newest.get('path')
        print(f"Latest patch: {newest_path}")

    # Patch directories that would be used for today's date
    today = datetime.datetime.now().strftime("%Y%m%d")
    selected_dirs = selector.load_patches(today)
    print(f"Loaded patches for {today}: {len(selected_dirs)}")

    # One numbered line per selected directory, counting from 1
    for position, patch_dir in enumerate(selected_dirs, start=1):
        print(f" [{position}] {patch_dir}")
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()