# ContinuumAgent / app\router.py
# deasdutta's picture
# Upload app\router.py with huggingface_hub
# 33be500 verified
#!/usr/bin/env python
"""
Continuum Router for ContinuumAgent Project
Routes requests between base model and patched model based on query complexity
"""
import os
import time
from typing import Dict, List, Any, Optional, Tuple, Union
from runtime.gguf_lora_runtime import GGUFLoraRuntime
from runtime.difficulty_gate import DifficultyGate
from runtime.lora_mux import LoraMux
class ContinuumRouter:
    """
    Routes requests between the base model and the patched (LoRA-adapted) model.

    The router wires together three components:
      * ``runtime``  -- GGUFLoraRuntime used for all text generation,
                        optionally with the loaded adapters applied
      * ``gate``     -- DifficultyGate consulted to decide whether a query
                        should be answered with patches
      * ``lora_mux`` -- LoraMux used to list available patches and find
                        the latest one

    ``generate`` also maintains two counters: ``request_count`` (every call)
    and ``patch_usage_count`` (calls that ran with patches). ``benchmark``
    and ``compare_outputs`` call the runtime directly and do not touch them.
    """

    def __init__(self,
                 model_path: str,
                 registry_dir: str = "models/registry",
                 n_gpu_layers: int = -1):
        """
        Initialize the continuum router

        Args:
            model_path: Path to GGUF model file
            registry_dir: Path to LoRA registry directory
            n_gpu_layers: Number of layers to offload to GPU (-1 for all)
        """
        self.model_path = model_path
        self.registry_dir = registry_dir
        self.n_gpu_layers = n_gpu_layers

        # Model file name; also parsed later for the quantization tag
        self.model_name = os.path.basename(model_path)

        # Initialize components
        print("Initializing GGUF runtime...")
        self.runtime = GGUFLoraRuntime(
            model_path=model_path,
            registry_dir=registry_dir,
            n_gpu_layers=n_gpu_layers
        )

        print("Initializing difficulty gate...")
        self.gate = DifficultyGate(
            model_path=model_path,
            n_gpu_layers=0  # Use CPU for gate model (lightweight)
        )

        print("Initializing LoRA mux...")
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Statistics (updated only by generate())
        self.request_count = 0
        self.patch_usage_count = 0

    def _adapter_snapshot(self) -> List[str]:
        """Return a copy of the runtime's loaded adapter paths ([] if unset)."""
        # Copy so results handed to callers do not change when the runtime
        # later loads different adapters, and callers cannot mutate runtime state.
        return list(self.runtime.loaded_adapters or [])

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information

        Returns:
            Dictionary with keys "name", "quantization" (the text between the
            last ".Q" in the file name and the next ".", or "unknown"),
            "patches" (see list_patches) and "using_gpu" (n_gpu_layers != 0).
        """
        quant_format = "unknown"
        if ".Q" in self.model_name:
            # Use the LAST ".Q" so stems such as "org.Qwen2-7B.Q4_K_M.gguf"
            # are not mis-parsed at the first match.
            quant_format = self.model_name.rsplit(".Q", 1)[1].split(".")[0]

        return {
            "name": self.model_name,
            "quantization": quant_format,
            "patches": self.list_patches(),
            "using_gpu": self.n_gpu_layers != 0
        }

    def list_patches(self) -> List[Dict[str, Any]]:
        """
        List available patches

        Returns:
            List of patch info dictionaries, as reported by the LoRA mux
        """
        return self.lora_mux.get_available_patches()

    def get_active_patches(self) -> List[str]:
        """
        Get currently active patches

        Returns:
            A copy of the list of active patch paths
        """
        return self._adapter_snapshot()

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load patches for a specific date

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of loaded patch paths
        """
        return self.runtime.load_adapters(date_str)

    def load_latest_patches(self) -> List[str]:
        """
        Load the patches belonging to the most recent registry entry

        The date is taken as the first path component of the latest patch's
        "path" entry. If the path has no directory component, None is passed
        to load_patches (which then uses its own default).

        Returns:
            List of loaded patch paths ([] if no patches are available)
        """
        latest_patch = self.lora_mux.get_latest_patch()
        if not latest_patch:
            print("No patches available")
            return []

        # Normalise Windows separators so "20240101\\x.gguf" is handled the
        # same way as "20240101/x.gguf".
        path = latest_patch.get("path", "").replace("\\", "/")
        date_str = path.split("/")[0] if "/" in path else None

        return self.load_patches(date_str)

    def should_use_patches(self, query: str, force_patches: Optional[bool] = None) -> bool:
        """
        Determine if patches should be used for the query

        Args:
            query: Query string
            force_patches: If not None, returned as-is without consulting the gate

        Returns:
            Boolean decision
        """
        if force_patches is not None:
            return force_patches

        decision = self.gate.should_use_patches(query)
        return decision["needs_patches"]

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.95,
                 auto_route: bool = True,
                 force_patches: Optional[bool] = None) -> Dict[str, Any]:
        """
        Generate response with appropriate model

        Routing:
            * auto_route=False: use force_patches, defaulting to True if None.
            * auto_route=True: force_patches wins if given, else ask the gate.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            auto_route: Whether to use automatic routing
            force_patches: Force using or not using patches

        Returns:
            Dict with "text", "elapsed_seconds" (as reported by the runtime),
            "used_patches", "adapter_paths" (a copy; [] when patches unused)
            and "total_tokens" (whitespace word count of prompt + text,
            an approximation).
        """
        self.request_count += 1

        if not auto_route:
            use_patches = force_patches if force_patches is not None else True
        else:
            use_patches = self.should_use_patches(prompt, force_patches)

        result = self.runtime.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            with_adapters=use_patches
        )

        if use_patches:
            self.patch_usage_count += 1

        return {
            "text": result["text"],
            "elapsed_seconds": result["elapsed_seconds"],
            "used_patches": use_patches,
            "adapter_paths": self._adapter_snapshot() if use_patches else [],
            "total_tokens": len(prompt.split()) + len(result["text"].split())  # Approximate
        }

    def benchmark(self,
                  queries: List[str],
                  with_patches: bool = True,
                  max_tokens: int = 256) -> Dict[str, Any]:
        """
        Run benchmark on a list of queries

        Args:
            queries: List of query strings
            with_patches: Whether to use patches
            max_tokens: Maximum tokens to generate

        Returns:
            Dict with "num_queries", "total_time", "average_time" (0 for an
            empty list), "with_patches" and per-query "results" entries
            ("query", "elapsed_seconds", "tokens" = whitespace word count).
        """
        results = []
        total_time = 0.0

        for query in queries:
            # perf_counter is monotonic, unlike time.time(), so durations
            # cannot be skewed by wall-clock adjustments.
            started = time.perf_counter()
            response = self.runtime.generate(
                prompt=query,
                max_tokens=max_tokens,
                with_adapters=with_patches
            )
            elapsed = time.perf_counter() - started
            total_time += elapsed

            results.append({
                "query": query,
                "elapsed_seconds": elapsed,
                "tokens": len(response["text"].split())
            })

        avg_time = total_time / len(queries) if queries else 0

        return {
            "num_queries": len(queries),
            "total_time": total_time,
            "average_time": avg_time,
            "with_patches": with_patches,
            "results": results
        }

    def compare_outputs(self,
                        query: str,
                        max_tokens: int = 256) -> Dict[str, Any]:
        """
        Compare outputs from base model and patched model

        Args:
            query: Query string
            max_tokens: Maximum tokens to generate

        Returns:
            Dict with "query", "base_output", "patched_output", and the
            runtime-reported "base_time" / "patched_time".
        """
        base_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=False
        )

        patched_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=True
        )

        return {
            "query": query,
            "base_output": base_result["text"],
            "patched_output": patched_result["text"],
            "base_time": base_result["elapsed_seconds"],
            "patched_time": patched_result["elapsed_seconds"]
        }