File size: 5,892 Bytes
dc39aaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/env python
"""

LoraMux for ContinuumAgent Project

Dynamically loads and applies LoRA patches at inference time

"""

import os
import json
import glob
from typing import List, Dict, Any, Optional
import datetime

class LoraMux:
    """Manages and applies LoRA patches dynamically.

    Patch metadata is read from ``<registry_dir>/index.json``, which is
    expected to be a JSON object with a ``"patches"`` list. The per-entry keys
    read by this class are ``"path"`` (joined onto ``registry_dir``) and
    ``"created_at"`` (compared as-is to find the newest patch).
    """

    def __init__(self, registry_dir: str = "models/registry"):
        """Initialize the LoraMux and load the registry index.

        Args:
            registry_dir: Directory containing the LoRA patch registry
                (``index.json`` plus one sub-directory per patch).
        """
        self.registry_dir = registry_dir
        self.index_path = os.path.join(registry_dir, "index.json")
        # Not read or written anywhere else in this class.
        self.active_patches = {}

        # Load registry index
        self._load_registry_index()

    def _load_registry_index(self) -> None:
        """Load ``self.index_path`` into ``self.registry_index``.

        Falls back to an empty registry (``{"patches": []}``) when the file is
        missing, unreadable, not valid JSON, or not shaped like a registry
        index. Problems are reported with ``print`` rather than raised.
        """
        if not os.path.exists(self.index_path):
            print(f"Registry index not found at {self.index_path}")
            self.registry_index = {"patches": []}
            return

        try:
            with open(self.index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading registry index: {e}")
            self.registry_index = {"patches": []}
            return

        # Every other method calls .get("patches", []) and iterates/sorts the
        # result, so reject anything that would break those assumptions.
        if not isinstance(index, dict) or not isinstance(index.get("patches", []), list):
            print(f"Error loading registry index: unexpected structure in {self.index_path}")
            self.registry_index = {"patches": []}
            return

        self.registry_index = index
        print(f"Loaded registry with {len(index.get('patches', []))} patches")

    def get_available_patches(self, date: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the registered patches, optionally filtered by date.

        Args:
            date: Date string in YYYYMMDD format (optional). When given, only
                patches whose ``"path"`` starts with this string are returned.

        Returns:
            A new list of patch info dictionaries. Mutating the returned list
            does not affect the loaded registry index.
        """
        patches = self.registry_index.get("patches", [])

        if date:
            # Matching is on the "path" prefix, so this presumably expects
            # patch directories to be named with a leading YYYYMMDD date.
            return [p for p in patches if p.get("path", "").startswith(date)]

        # Copy so callers cannot reorder or mutate the registry index itself.
        return list(patches)

    def get_latest_patch(self) -> Optional[Dict[str, Any]]:
        """Return the patch with the greatest ``"created_at"`` value.

        Entries without ``"created_at"`` compare as the empty string. The
        registry index is not reordered.

        Returns:
            Patch info dictionary, or None if no patches are available.
        """
        patches = self.registry_index.get("patches", [])

        if not patches:
            return None

        # max() returns the first maximal entry, which is the same entry a
        # stable descending sort would put first, but without the in-place
        # mutation of the registry list.
        return max(patches, key=lambda p: p.get("created_at", ""))

    def get_patches_for_date(self, date_str: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the patches for a date, falling back to the newest patch.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today's
                local date).

        Returns:
            The patches whose path starts with ``date_str``. If there are none,
            a single-element list holding the newest patch overall. An empty
            list if the registry has no patches at all.
        """
        if not date_str:
            # Default to today's date (local time).
            date_str = datetime.datetime.now().strftime("%Y%m%d")

        patches = self.get_available_patches(date_str)
        if patches:
            return patches

        print(f"No patches found for date {date_str}")

        # Fall back to the newest patch overall. NOTE(review): this is not
        # restricted to patches created before ``date_str``, so a past date
        # can resolve to a newer patch. Confirm this is the intended policy.
        latest = self.get_latest_patch()
        if latest is None:
            return []

        print(f"Using most recent patch instead: {latest.get('path')}")
        return [latest]

    def get_patch_path(self, patch_info: Dict[str, Any]) -> str:
        """Return the full path to a patch directory.

        Args:
            patch_info: Patch info dictionary. Its ``"path"`` value is joined
                onto ``registry_dir``; a missing path yields ``registry_dir``
                itself.

        Returns:
            Path to the patch directory.
        """
        return os.path.join(self.registry_dir, patch_info.get("path", ""))

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """Select and validate the patch directories for a date.

        No weights are read here. Patches are chosen via
        ``get_patches_for_date`` and kept only if their directory exists and
        contains an ``adapter_model.bin`` file. A warning is printed for each
        rejected path.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today).

        Returns:
            List of valid patch directory paths.
        """
        patches = self.get_patches_for_date(date_str)
        patch_paths = [self.get_patch_path(p) for p in patches]

        print(f"Selected {len(patch_paths)} patches for loading")

        valid_paths = []
        for path in patch_paths:
            # isdir() already implies existence.
            if not os.path.isdir(path):
                print(f"Warning: Invalid patch path {path}")
                continue

            adapter_model_path = os.path.join(path, "adapter_model.bin")
            if not os.path.isfile(adapter_model_path):
                print(f"Warning: Missing adapter_model.bin in {path}")
                continue

            valid_paths.append(path)

        return valid_paths

def main():
    """Smoke-test LoraMux against the default registry and print a summary."""
    registry = LoraMux()

    # Everything the index knows about.
    available = registry.get_available_patches()
    print(f"Available patches: {len(available)}")

    # Newest entry by created_at, if any.
    newest = registry.get_latest_patch()
    if newest:
        print(f"Latest patch: {newest.get('path')}")

    # Validated patch directories for today's date.
    stamp = datetime.datetime.now().strftime("%Y%m%d")
    paths = registry.load_patches(stamp)
    print(f"Loaded patches for {stamp}: {len(paths)}")

    for idx, patch_dir in enumerate(paths, start=1):
        print(f"  [{idx}] {patch_dir}")

if __name__ == "__main__":
    main()