|
|
|
|
|
"""
|
|
|
Continuum Router for ContinuumAgent Project
|
|
|
Routes requests between base model and patched model based on query complexity
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
from typing import Dict, List, Any, Optional, Tuple, Union
|
|
|
from runtime.gguf_lora_runtime import GGUFLoraRuntime
|
|
|
from runtime.difficulty_gate import DifficultyGate
|
|
|
from runtime.lora_mux import LoraMux
|
|
|
|
|
|
class ContinuumRouter:
    """
    Routes requests between base model and patched model.

    Each query is classified by a difficulty gate; queries judged to need
    patches are generated with LoRA adapters applied, others use the plain
    base model. Callers can bypass the gate with ``force_patches``.
    """

    def __init__(self,
                 model_path: str,
                 registry_dir: str = "models/registry",
                 n_gpu_layers: int = -1):
        """
        Initialize the continuum router.

        Args:
            model_path: Path to GGUF model file
            registry_dir: Path to LoRA registry directory
            n_gpu_layers: Number of layers to offload to GPU (-1 for all)
        """
        self.model_path = model_path
        self.registry_dir = registry_dir
        self.n_gpu_layers = n_gpu_layers

        # Display name is the file name only, no directory components.
        self.model_name = os.path.basename(model_path)

        print("Initializing GGUF runtime...")
        self.runtime = GGUFLoraRuntime(
            model_path=model_path,
            registry_dir=registry_dir,
            n_gpu_layers=n_gpu_layers
        )

        print("Initializing difficulty gate...")
        # The gate is constructed with n_gpu_layers=0, so it does not
        # compete with the main runtime for GPU memory.
        self.gate = DifficultyGate(
            model_path=model_path,
            n_gpu_layers=0
        )

        print("Initializing LoRA mux...")
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Simple usage counters for observability.
        self.request_count = 0
        self.patch_usage_count = 0

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with the model name, quantization format, available
            patches, and whether any GPU offload is configured.
        """
        # Parse the quantization suffix from names like "model.Q4_K_M.gguf".
        quant_format = "unknown"
        if ".Q" in self.model_name:
            # Re-attach the leading "Q" so the value matches the
            # conventional GGUF quantization label (e.g. "Q4_K_M" rather
            # than "4_K_M", which the previous split produced).
            quant_format = "Q" + self.model_name.split(".Q")[1].split(".")[0]

        patches = self.list_patches()

        return {
            "name": self.model_name,
            "quantization": quant_format,
            "patches": patches,
            "using_gpu": self.n_gpu_layers != 0
        }

    def list_patches(self) -> List[Dict[str, Any]]:
        """
        List available patches.

        Returns:
            List of patch info dictionaries from the LoRA mux registry.
        """
        return self.lora_mux.get_available_patches()

    def get_active_patches(self) -> List[str]:
        """
        Get currently active patches.

        Returns:
            List of adapter paths currently loaded into the runtime.
        """
        return self.runtime.loaded_adapters

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load patches for a specific date.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of loaded patch paths
        """
        return self.runtime.load_adapters(date_str)

    def load_latest_patches(self) -> List[str]:
        """
        Load the most recent patches available in the registry.

        Returns:
            List of loaded patch paths (empty if no patch exists).
        """
        latest_patch = self.lora_mux.get_latest_patch()

        if not latest_patch:
            print("No patches available")
            return []

        # NOTE(review): assumes registry-relative paths of the form
        # "YYYYMMDD/..." with forward slashes — confirm against the
        # registry layout; absolute or Windows-style paths would yield a
        # wrong date_str here.
        path = latest_patch.get("path", "")
        date_str = path.split("/")[0] if "/" in path else None

        return self.load_patches(date_str)

    def should_use_patches(self, query: str, force_patches: Optional[bool] = None) -> bool:
        """
        Determine if patches should be used for the query.

        Args:
            query: Query string
            force_patches: Force using or not using patches; when given,
                the difficulty gate is skipped entirely.

        Returns:
            Boolean decision
        """
        if force_patches is not None:
            return force_patches

        decision = self.gate.should_use_patches(query)
        return decision["needs_patches"]

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.95,
                 auto_route: bool = True,
                 force_patches: Optional[bool] = None) -> Dict[str, Any]:
        """
        Generate a response, routing to the base or patched model.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            auto_route: Whether to use automatic routing via the gate
            force_patches: Force using or not using patches

        Returns:
            Dictionary with generated text, elapsed seconds, whether
            patches were used, the adapter paths applied, and a rough
            whitespace-based token count.
        """
        self.request_count += 1

        if not auto_route:
            # With routing disabled, default to patches unless the caller
            # explicitly opted out.
            use_patches = force_patches if force_patches is not None else True
        else:
            use_patches = self.should_use_patches(prompt, force_patches)

        # Timing is reported by the runtime itself (elapsed_seconds); the
        # previous local stopwatch here was dead code and has been removed.
        result = self.runtime.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            with_adapters=use_patches
        )

        if use_patches:
            self.patch_usage_count += 1

        return {
            "text": result["text"],
            "elapsed_seconds": result["elapsed_seconds"],
            "used_patches": use_patches,
            "adapter_paths": self.runtime.loaded_adapters if use_patches else [],
            # Whitespace split, not tokenizer-accurate — an estimate only.
            "total_tokens": len(prompt.split()) + len(result["text"].split())
        }

    def benchmark(self,
                  queries: List[str],
                  with_patches: bool = True,
                  max_tokens: int = 256) -> Dict[str, Any]:
        """
        Run benchmark on a list of queries.

        Args:
            queries: List of query strings
            with_patches: Whether to use patches
            max_tokens: Maximum tokens to generate

        Returns:
            Benchmark results: per-query timings/token counts plus totals.
        """
        results = []
        total_time = 0

        for query in queries:
            # Wall-clock the full call (including adapter overhead) rather
            # than relying on the runtime's self-reported timing.
            start_time = time.time()

            response = self.runtime.generate(
                prompt=query,
                max_tokens=max_tokens,
                with_adapters=with_patches
            )

            elapsed = time.time() - start_time
            total_time += elapsed

            results.append({
                "query": query,
                "elapsed_seconds": elapsed,
                # Whitespace token estimate, consistent with generate().
                "tokens": len(response["text"].split())
            })

        # Guard against an empty query list to avoid ZeroDivisionError.
        avg_time = total_time / len(queries) if queries else 0

        return {
            "num_queries": len(queries),
            "total_time": total_time,
            "average_time": avg_time,
            "with_patches": with_patches,
            "results": results
        }

    def compare_outputs(self,
                        query: str,
                        max_tokens: int = 256) -> Dict[str, Any]:
        """
        Compare outputs from base model and patched model.

        Args:
            query: Query string
            max_tokens: Maximum tokens to generate

        Returns:
            Dictionary with both outputs and their generation times.
        """
        base_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=False
        )

        patched_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=True
        )

        return {
            "query": query,
            "base_output": base_result["text"],
            "patched_output": patched_result["text"],
            "base_time": base_result["elapsed_seconds"],
            "patched_time": patched_result["elapsed_seconds"]
        }