deasdutta commited on
Commit
dc39aaf
·
verified ·
1 Parent(s): 0a20b7b

Upload runtime\lora_mux.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. runtime//lora_mux.py +180 -0
runtime//lora_mux.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ LoraMux for ContinuumAgent Project
4
+ Dynamically loads and applies LoRA patches at inference time
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import glob
10
+ from typing import List, Dict, Any, Optional
11
+ import datetime
12
+
13
+ class LoraMux:
14
+ """Manages and applies LoRA patches dynamically"""
15
+
16
+ def __init__(self, registry_dir: str = "models/registry"):
17
+ """
18
+ Initialize the LoraMux
19
+
20
+ Args:
21
+ registry_dir: Directory containing LoRA patch registry
22
+ """
23
+ self.registry_dir = registry_dir
24
+ self.index_path = os.path.join(registry_dir, "index.json")
25
+ self.active_patches = {}
26
+
27
+ # Load registry index
28
+ self._load_registry_index()
29
+
30
+ def _load_registry_index(self) -> None:
31
+ """Load registry index file"""
32
+ if not os.path.exists(self.index_path):
33
+ print(f"Registry index not found at {self.index_path}")
34
+ self.registry_index = {"patches": []}
35
+ return
36
+
37
+ try:
38
+ with open(self.index_path, "r") as f:
39
+ self.registry_index = json.load(f)
40
+
41
+ print(f"Loaded registry with {len(self.registry_index.get('patches', []))} patches")
42
+ except Exception as e:
43
+ print(f"Error loading registry index: {e}")
44
+ self.registry_index = {"patches": []}
45
+
46
+ def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
47
+ """
48
+ Get list of available patches, optionally filtered by date
49
+
50
+ Args:
51
+ date: Date string in YYYYMMDD format (optional)
52
+
53
+ Returns:
54
+ List of patch info dictionaries
55
+ """
56
+ patches = self.registry_index.get("patches", [])
57
+
58
+ if date:
59
+ # Filter patches by date
60
+ patches = [p for p in patches if p.get("path", "").startswith(date)]
61
+
62
+ return patches
63
+
64
+ def get_latest_patch(self) -> Optional[Dict[str, Any]]:
65
+ """
66
+ Get the latest available patch
67
+
68
+ Returns:
69
+ Patch info dictionary or None if no patches available
70
+ """
71
+ patches = self.registry_index.get("patches", [])
72
+
73
+ if not patches:
74
+ return None
75
+
76
+ # Sort by created_at timestamp
77
+ patches.sort(key=lambda p: p.get("created_at", ""), reverse=True)
78
+
79
+ return patches[0]
80
+
81
+ def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
82
+ """
83
+ Get patches for a specific date
84
+
85
+ Args:
86
+ date_str: Date string in YYYYMMDD format (defaults to today)
87
+
88
+ Returns:
89
+ List of patch info dictionaries
90
+ """
91
+ if not date_str:
92
+ # Default to today's date
93
+ date_str = datetime.datetime.now().strftime("%Y%m%d")
94
+
95
+ # Get patches for date
96
+ patches = self.get_available_patches(date_str)
97
+
98
+ if not patches:
99
+ print(f"No patches found for date {date_str}")
100
+
101
+ # Try to find most recent patches before the specified date
102
+ all_patches = self.registry_index.get("patches", [])
103
+ if all_patches:
104
+ # Sort by created_at timestamp
105
+ all_patches.sort(key=lambda p: p.get("created_at", ""), reverse=True)
106
+
107
+ # Return most recent patch
108
+ print(f"Using most recent patch instead: {all_patches[0].get('path')}")
109
+ return [all_patches[0]]
110
+
111
+ return patches
112
+
113
+ def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
114
+ """
115
+ Get full path to patch directory
116
+
117
+ Args:
118
+ patch_info: Patch info dictionary
119
+
120
+ Returns:
121
+ Path to patch directory
122
+ """
123
+ return os.path.join(self.registry_dir, patch_info.get("path", ""))
124
+
125
+ def load_patches(self, date_str: Optional[str] = None) -> List[str]:
126
+ """
127
+ Load patches for a specific date
128
+
129
+ Args:
130
+ date_str: Date string in YYYYMMDD format (defaults to today)
131
+
132
+ Returns:
133
+ List of loaded patch paths
134
+ """
135
+ # Get patches for date
136
+ patches = self.get_patches_for_date(date_str)
137
+
138
+ # Generate paths
139
+ patch_paths = [self.get_patch_path(p) for p in patches]
140
+
141
+ print(f"Selected {len(patch_paths)} patches for loading")
142
+
143
+ # Validate paths
144
+ valid_paths = []
145
+ for path in patch_paths:
146
+ if os.path.exists(path) and os.path.isdir(path):
147
+ adapter_model_path = os.path.join(path, "adapter_model.bin")
148
+ if os.path.exists(adapter_model_path):
149
+ valid_paths.append(path)
150
+ else:
151
+ print(f"Warning: Missing adapter_model.bin in {path}")
152
+ else:
153
+ print(f"Warning: Invalid patch path {path}")
154
+
155
+ return valid_paths
156
+
157
+ def main():
158
+ """Test LoraMux functionality"""
159
+ mux = LoraMux()
160
+
161
+ # List available patches
162
+ patches = mux.get_available_patches()
163
+ print(f"Available patches: {len(patches)}")
164
+
165
+ # Get latest patch
166
+ latest = mux.get_latest_patch()
167
+ if latest:
168
+ print(f"Latest patch: {latest.get('path')}")
169
+
170
+ # Load patches for today
171
+ today = datetime.datetime.now().strftime("%Y%m%d")
172
+ loaded = mux.load_patches(today)
173
+ print(f"Loaded patches for {today}: {len(loaded)}")
174
+
175
+ # List patch paths
176
+ for i, path in enumerate(loaded):
177
+ print(f" [{i+1}] {path}")
178
+
179
+ if __name__ == "__main__":
180
+ main()