deasdutta committed on
Commit
33be500
·
verified ·
1 Parent(s): 19f6105

Upload app\router.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app//router.py +293 -0
app//router.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Continuum Router for ContinuumAgent Project
4
+ Routes requests between base model and patched model based on query complexity
5
+ """
6
+
7
+ import os
8
+ import time
9
+ from typing import Dict, List, Any, Optional, Tuple, Union
10
+ from runtime.gguf_lora_runtime import GGUFLoraRuntime
11
+ from runtime.difficulty_gate import DifficultyGate
12
+ from runtime.lora_mux import LoraMux
13
+
14
class ContinuumRouter:
    """Routes requests between the base model and the patched model.

    A lightweight difficulty gate decides, per query, whether LoRA
    patches should be applied on top of the base GGUF model.
    """

    def __init__(self,
                 model_path: str,
                 registry_dir: str = "models/registry",
                 n_gpu_layers: int = -1):
        """
        Initialize the continuum router

        Args:
            model_path: Path to GGUF model file
            registry_dir: Path to LoRA registry directory
            n_gpu_layers: Number of layers to offload to GPU (-1 for all)
        """
        self.model_path = model_path
        self.registry_dir = registry_dir
        self.n_gpu_layers = n_gpu_layers

        # File name only; used for reporting and quantization detection.
        self.model_name = os.path.basename(model_path)

        # Main inference runtime (may offload layers to the GPU).
        print("Initializing GGUF runtime...")
        self.runtime = GGUFLoraRuntime(
            model_path=model_path,
            registry_dir=registry_dir,
            n_gpu_layers=n_gpu_layers
        )

        # Gate model stays on CPU so routing decisions are cheap.
        print("Initializing difficulty gate...")
        self.gate = DifficultyGate(
            model_path=model_path,
            n_gpu_layers=0  # Use CPU for gate model (lightweight)
        )

        print("Initializing LoRA mux...")
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Usage counters, updated by generate().
        self.request_count = 0
        self.patch_usage_count = 0
58
+
59
+ def get_model_info(self) -> Dict[str, Any]:
60
+ """
61
+ Get model information
62
+
63
+ Returns:
64
+ Dictionary with model information
65
+ """
66
+ # Extract quantization format from model name
67
+ quant_format = "unknown"
68
+ if ".Q" in self.model_name:
69
+ quant_format = self.model_name.split(".Q")[1].split(".")[0]
70
+
71
+ # Get available patches
72
+ patches = self.list_patches()
73
+
74
+ # Create model info
75
+ return {
76
+ "name": self.model_name,
77
+ "quantization": quant_format,
78
+ "patches": patches,
79
+ "using_gpu": self.n_gpu_layers != 0
80
+ }
81
+
82
+ def list_patches(self) -> List[Dict[str, Any]]:
83
+ """
84
+ List available patches
85
+
86
+ Returns:
87
+ List of patch info dictionaries
88
+ """
89
+ return self.lora_mux.get_available_patches()
90
+
91
+ def get_active_patches(self) -> List[str]:
92
+ """
93
+ Get currently active patches
94
+
95
+ Returns:
96
+ List of active patch paths
97
+ """
98
+ return self.runtime.loaded_adapters
99
+
100
+ def load_patches(self, date_str: Optional[str] = None) -> List[str]:
101
+ """
102
+ Load patches for a specific date
103
+
104
+ Args:
105
+ date_str: Date string in YYYYMMDD format (defaults to today)
106
+
107
+ Returns:
108
+ List of loaded patch paths
109
+ """
110
+ return self.runtime.load_adapters(date_str)
111
+
112
+ def load_latest_patches(self) -> List[str]:
113
+ """
114
+ Load latest patches
115
+
116
+ Returns:
117
+ List of loaded patch paths
118
+ """
119
+ # Get latest patch
120
+ latest_patch = self.lora_mux.get_latest_patch()
121
+
122
+ if not latest_patch:
123
+ print("No patches available")
124
+ return []
125
+
126
+ # Extract date from path
127
+ path = latest_patch.get("path", "")
128
+ date_str = path.split("/")[0] if "/" in path else None
129
+
130
+ # Load patches
131
+ return self.load_patches(date_str)
132
+
133
+ def should_use_patches(self, query: str, force_patches: Optional[bool] = None) -> bool:
134
+ """
135
+ Determine if patches should be used for the query
136
+
137
+ Args:
138
+ query: Query string
139
+ force_patches: Force using or not using patches
140
+
141
+ Returns:
142
+ Boolean decision
143
+ """
144
+ # If force_patches is specified, use that decision
145
+ if force_patches is not None:
146
+ return force_patches
147
+
148
+ # Otherwise, use the gate to decide
149
+ decision = self.gate.should_use_patches(query)
150
+ return decision["needs_patches"]
151
+
152
+ def generate(self,
153
+ prompt: str,
154
+ system_prompt: Optional[str] = None,
155
+ max_tokens: int = 256,
156
+ temperature: float = 0.7,
157
+ top_p: float = 0.95,
158
+ auto_route: bool = True,
159
+ force_patches: Optional[bool] = None) -> Dict[str, Any]:
160
+ """
161
+ Generate response with appropriate model
162
+
163
+ Args:
164
+ prompt: User prompt
165
+ system_prompt: Optional system prompt
166
+ max_tokens: Maximum tokens to generate
167
+ temperature: Sampling temperature
168
+ top_p: Top-p sampling parameter
169
+ auto_route: Whether to use automatic routing
170
+ force_patches: Force using or not using patches
171
+
172
+ Returns:
173
+ Generation result
174
+ """
175
+ # Update request count
176
+ self.request_count += 1
177
+
178
+ # Determine if patches should be used
179
+ if not auto_route:
180
+ # Use patches based on force_patches (default to True if not specified)
181
+ use_patches = force_patches if force_patches is not None else True
182
+ else:
183
+ # Use gate to decide
184
+ use_patches = self.should_use_patches(prompt, force_patches)
185
+
186
+ # Generate response
187
+ start_time = time.time()
188
+
189
+ result = self.runtime.generate(
190
+ prompt=prompt,
191
+ system_prompt=system_prompt,
192
+ max_tokens=max_tokens,
193
+ temperature=temperature,
194
+ top_p=top_p,
195
+ with_adapters=use_patches
196
+ )
197
+
198
+ # Update statistics
199
+ if use_patches:
200
+ self.patch_usage_count += 1
201
+
202
+ # Format response
203
+ return {
204
+ "text": result["text"],
205
+ "elapsed_seconds": result["elapsed_seconds"],
206
+ "used_patches": use_patches,
207
+ "adapter_paths": self.runtime.loaded_adapters if use_patches else [],
208
+ "total_tokens": len(prompt.split()) + len(result["text"].split()) # Approximate
209
+ }
210
+
211
+ def benchmark(self,
212
+ queries: List[str],
213
+ with_patches: bool = True,
214
+ max_tokens: int = 256) -> Dict[str, Any]:
215
+ """
216
+ Run benchmark on a list of queries
217
+
218
+ Args:
219
+ queries: List of query strings
220
+ with_patches: Whether to use patches
221
+ max_tokens: Maximum tokens to generate
222
+
223
+ Returns:
224
+ Benchmark results
225
+ """
226
+ results = []
227
+ total_time = 0
228
+
229
+ for query in queries:
230
+ # Generate response
231
+ start_time = time.time()
232
+
233
+ response = self.runtime.generate(
234
+ prompt=query,
235
+ max_tokens=max_tokens,
236
+ with_adapters=with_patches
237
+ )
238
+
239
+ elapsed = time.time() - start_time
240
+ total_time += elapsed
241
+
242
+ # Add to results
243
+ results.append({
244
+ "query": query,
245
+ "elapsed_seconds": elapsed,
246
+ "tokens": len(response["text"].split())
247
+ })
248
+
249
+ # Calculate statistics
250
+ avg_time = total_time / len(queries) if queries else 0
251
+
252
+ return {
253
+ "num_queries": len(queries),
254
+ "total_time": total_time,
255
+ "average_time": avg_time,
256
+ "with_patches": with_patches,
257
+ "results": results
258
+ }
259
+
260
+ def compare_outputs(self,
261
+ query: str,
262
+ max_tokens: int = 256) -> Dict[str, Any]:
263
+ """
264
+ Compare outputs from base model and patched model
265
+
266
+ Args:
267
+ query: Query string
268
+ max_tokens: Maximum tokens to generate
269
+
270
+ Returns:
271
+ Comparison results
272
+ """
273
+ # Generate with base model
274
+ base_result = self.runtime.generate(
275
+ prompt=query,
276
+ max_tokens=max_tokens,
277
+ with_adapters=False
278
+ )
279
+
280
+ # Generate with patched model
281
+ patched_result = self.runtime.generate(
282
+ prompt=query,
283
+ max_tokens=max_tokens,
284
+ with_adapters=True
285
+ )
286
+
287
+ return {
288
+ "query": query,
289
+ "base_output": base_result["text"],
290
+ "patched_output": patched_result["text"],
291
+ "base_time": base_result["elapsed_seconds"],
292
+ "patched_time": patched_result["elapsed_seconds"]
293
+ }