deasdutta committed on
Commit
0a20b7b
·
verified ·
1 Parent(s): 8838cce

Upload runtime\gguf_lora_runtime.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. runtime//gguf_lora_runtime.py +256 -0
runtime//gguf_lora_runtime.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ GGUF LoRA Runtime for ContinuumAgent Project
4
+ Integrates LoRA patches with llama-cpp-python GGUF models
5
+ Modified for better CPU compatibility
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import time
11
+ from typing import List, Dict, Any, Optional, Union
12
+ from llama_cpp import Llama
13
+ from runtime.lora_mux import LoraMux
14
+
15
class GGUFLoraRuntime:
    """Runtime for applying LoRA patches to GGUF models.

    Wraps a llama-cpp-python ``Llama`` instance and a ``LoraMux`` adapter
    registry. Note that actual adapter application is currently a stub
    (see ``load_adapters``); completions always run on the base model.
    """

    def __init__(self,
                 model_path: str,
                 registry_dir: str = "models/registry",
                 n_gpu_layers: int = 0,  # Force CPU-only by default
                 n_ctx: int = 1024,  # Reduced context size for better memory usage
                 verbose: bool = False):
        """
        Initialize the GGUF LoRA runtime.

        Args:
            model_path: Path to GGUF model file
            registry_dir: Path to LoRA registry directory
            n_gpu_layers: Number of layers to offload to GPU (0 for CPU-only);
                may be overridden by the N_GPU_LAYERS environment variable
            n_ctx: Context size
            verbose: Enable verbose output
        """
        self.model_path = model_path
        self.registry_dir = registry_dir

        # Allow N_GPU_LAYERS in the environment to override the argument.
        # A malformed value must not crash initialization, so fall back to
        # the constructor argument instead of letting int() raise ValueError.
        env_n_gpu_layers = os.environ.get("N_GPU_LAYERS")
        if env_n_gpu_layers is not None:
            try:
                self.n_gpu_layers = int(env_n_gpu_layers)
            except ValueError:
                print(f"Invalid N_GPU_LAYERS value {env_n_gpu_layers!r}; "
                      f"using default {n_gpu_layers}")
                self.n_gpu_layers = n_gpu_layers
        else:
            self.n_gpu_layers = n_gpu_layers

        self.n_ctx = n_ctx
        self.verbose = verbose

        # Adapter registry/multiplexer used to discover LoRA patches by date.
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Paths of adapters considered "loaded" (see load_adapters stub).
        self.loaded_adapters: List[str] = []

        # llama-cpp model instance; stays None if loading fails.
        self.model: Optional[Llama] = None

        # Best-effort model load: keep the object usable (e.g. for
        # inspection) even when the model file cannot be loaded. Failures
        # surface later as an error result from complete().
        try:
            self._load_base_model()
        except Exception as e:
            print(f"Error loading base model: {e}")
            print("Continuing with model as None - this will cause failures later but allows initialization")

    def _load_base_model(self) -> None:
        """Load the base GGUF model into ``self.model``.

        Raises:
            Exception: Re-raises whatever ``Llama`` raises on failure,
                after printing a diagnostic.
        """
        print(f"Loading base GGUF model from {self.model_path}...")

        try:
            # Additional parameters for better CPU performance.
            self.model = Llama(
                model_path=self.model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                verbose=self.verbose,
                seed=42,        # Set seed for reproducibility
                n_threads=4,    # Use 4 threads for CPU
                n_batch=512     # Smaller batch size for CPU
            )
            print("Base model loaded successfully")
        except Exception as e:
            print(f"Error loading base model: {e}")
            raise

    def load_adapters(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load LoRA adapters for a specific date.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of loaded adapter paths (empty when none are available)
        """
        # Discover patch directories for the requested date.
        patch_paths = self.lora_mux.load_patches(date_str)

        if not patch_paths:
            print("No adapters available to load")
            return []

        # Reset loaded adapters so repeated calls don't accumulate paths.
        self.loaded_adapters = []

        for patch_path in patch_paths:
            try:
                adapter_path = os.path.join(patch_path, "adapter_model.bin")

                # NOTE: This is a hypothetical implementation, as llama-cpp-python
                # doesn't currently support dynamically loading LoRA adapters.
                # In a real implementation, we would need to use a custom build
                # or extension; the adapter is only *recorded*, not applied.

                # self.model.load_adapter(adapter_path)
                print(f"Loaded adapter from {adapter_path}")
                self.loaded_adapters.append(patch_path)

            except Exception as e:
                print(f"Error loading adapter from {patch_path}: {e}")

        print(f"Loaded {len(self.loaded_adapters)} adapters")
        return self.loaded_adapters

    def complete(self,
                 prompt: str,
                 max_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.95,
                 with_adapters: bool = True) -> Dict[str, Any]:
        """
        Generate a completion with the model.

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            with_adapters: Whether to use loaded adapters

        Returns:
            Dict with keys ``text``, ``elapsed_seconds``, ``with_adapters``
            and ``adapters_used``. On failure ``text`` carries an error
            marker instead of raising.
        """
        # Fail soft when the base model never loaded (see __init__).
        if self.model is None:
            return {
                "text": "[Error: Model not loaded]",
                "elapsed_seconds": 0.0,
                "with_adapters": with_adapters,
                "adapters_used": []
            }

        # Lazily load adapters on first use when requested.
        if with_adapters and not self.loaded_adapters:
            print("No adapters loaded, loading latest adapters...")
            self.load_adapters()

        start_time = time.time()

        try:
            # NOTE: In a real implementation, this would need to configure
            # the model to use/not use adapters based on with_adapters.
            completion = self.model.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop=["</s>"]  # Stop at end of sequence token
            )

            output_text = completion.get("choices", [{}])[0].get("text", "")
        except Exception as e:
            # Embed the error in the result so callers get a uniform shape.
            print(f"Error generating completion: {e}")
            output_text = f"[Error generating text: {str(e)}]"

        elapsed = time.time() - start_time

        result = {
            "text": output_text,
            "elapsed_seconds": elapsed,
            "with_adapters": with_adapters,
            "adapters_used": self.loaded_adapters if with_adapters else []
        }

        return result

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.95,
                 with_adapters: bool = True) -> Dict[str, Any]:
        """
        Generate a response using the Mistral chat format.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            with_adapters: Whether to use loaded adapters

        Returns:
            Generation result (same shape as ``complete``)
        """
        # NOTE(review): this template emits a closed </s> after the system
        # instruction with no model response in between, which differs from
        # the usual Mistral-instruct convention of merging the system prompt
        # into the first [INST] block — confirm against the model's expected
        # chat template before changing it.
        if system_prompt:
            formatted_prompt = f"<s>[INST] {system_prompt} [/INST]</s>[INST] {prompt} [/INST]"
        else:
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"

        result = self.complete(
            prompt=formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            with_adapters=with_adapters
        )

        return result
223
+
224
+
225
def main():
    """Smoke-test the GGUF LoRA runtime against the first model on disk."""
    model_dir = "models/slow"

    # Guard against a missing directory so the smoke test degrades
    # gracefully instead of raising FileNotFoundError from os.listdir.
    if not os.path.isdir(model_dir):
        print(f"Model directory not found: {model_dir}")
        return

    # Sort for a deterministic choice when several GGUF files exist;
    # os.listdir order is filesystem-dependent.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))

    if not model_files:
        print(f"No GGUF models found in {model_dir}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Initialize runtime with forced CPU mode and reduced context.
    runtime = GGUFLoraRuntime(
        model_path=model_path,
        n_gpu_layers=0,  # CPU only
        n_ctx=1024       # Reduced context
    )

    # Test simple completion
    print("Testing simple completion...")
    result = runtime.complete(
        prompt="Hello, world!",
        max_tokens=20
    )

    print(f"Completion: {result['text']}")
    print(f"Elapsed: {result['elapsed_seconds']:.2f}s")
254
+
255
# Script entry point: run the smoke test only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()