"""
Continuum Router for the ContinuumAgent project.

Routes requests between the base model and the patched model based on
query complexity.
"""

import os
import time
from typing import Any, Dict, List, Optional

from runtime.gguf_lora_runtime import GGUFLoraRuntime
from runtime.difficulty_gate import DifficultyGate
from runtime.lora_mux import LoraMux


class ContinuumRouter:
    """Routes requests between the base model and the patched model."""

    def __init__(self,
                 model_path: str,
                 registry_dir: str = "models/registry",
                 n_gpu_layers: int = -1):
        """
        Initialize the continuum router.

        Args:
            model_path: Path to the GGUF model file
            registry_dir: Path to the LoRA registry directory
            n_gpu_layers: Number of layers to offload to the GPU (-1 for all)
        """
        self.model_path = model_path
        self.registry_dir = registry_dir
        self.n_gpu_layers = n_gpu_layers

        self.model_name = os.path.basename(model_path)

        print("Initializing GGUF runtime...")
        self.runtime = GGUFLoraRuntime(
            model_path=model_path,
            registry_dir=registry_dir,
            n_gpu_layers=n_gpu_layers
        )

        print("Initializing difficulty gate...")
        # The gate runs on CPU (n_gpu_layers=0) so it does not compete with
        # the main runtime for GPU memory.
        self.gate = DifficultyGate(
            model_path=model_path,
            n_gpu_layers=0
        )

        print("Initializing LoRA mux...")
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Simple counters for routing statistics.
        self.request_count = 0
        self.patch_usage_count = 0

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with model information
        """
        # Infer the quantization format from the filename, e.g.
        # "model.Q4_K_M.gguf" -> "Q4_K_M". Splitting on ".Q" drops the
        # leading "Q", so it is added back here.
        quant_format = "unknown"
        if ".Q" in self.model_name:
            quant_format = "Q" + self.model_name.split(".Q")[1].split(".")[0]

        patches = self.list_patches()

        return {
            "name": self.model_name,
            "quantization": quant_format,
            "patches": patches,
            "using_gpu": self.n_gpu_layers != 0
        }
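
    # Illustrative get_model_info() result for a hypothetical
    # "llama-3.Q4_K_M.gguf" file (the filename is an assumption, not a file
    # shipped with this project):
    #   {"name": "llama-3.Q4_K_M.gguf", "quantization": "Q4_K_M",
    #    "patches": [...], "using_gpu": True}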

    def list_patches(self) -> List[Dict[str, Any]]:
        """
        List the available patches.

        Returns:
            List of patch info dictionaries
        """
        return self.lora_mux.get_available_patches()

    def get_active_patches(self) -> List[str]:
        """
        Get the currently active patches.

        Returns:
            List of active patch paths
        """
        return self.runtime.loaded_adapters

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load patches for a specific date.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of loaded patch paths
        """
        return self.runtime.load_adapters(date_str)

    def load_latest_patches(self) -> List[str]:
        """
        Load the latest patches.

        Returns:
            List of loaded patch paths
        """
        latest_patch = self.lora_mux.get_latest_patch()

        if not latest_patch:
            print("No patches available")
            return []

        # Registry paths are stored as "<YYYYMMDD>/<patch file>", so the
        # first path component is the patch date.
        path = latest_patch.get("path", "")
        date_str = path.split("/")[0] if "/" in path else None

        return self.load_patches(date_str)

    def should_use_patches(self, query: str, force_patches: Optional[bool] = None) -> bool:
        """
        Determine whether patches should be used for the query.

        Args:
            query: Query string
            force_patches: Force using or not using patches

        Returns:
            Boolean decision
        """
        # An explicit override short-circuits the difficulty gate.
        if force_patches is not None:
            return force_patches

        decision = self.gate.should_use_patches(query)
        return decision["needs_patches"]

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.95,
                 auto_route: bool = True,
                 force_patches: Optional[bool] = None) -> Dict[str, Any]:
        """
        Generate a response with the appropriate model.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            auto_route: Whether to use automatic routing
            force_patches: Force using or not using patches

        Returns:
            Generation result
        """
        self.request_count += 1

        if not auto_route:
            # Without auto-routing, default to patches unless explicitly
            # disabled.
            use_patches = force_patches if force_patches is not None else True
        else:
            use_patches = self.should_use_patches(prompt, force_patches)

        result = self.runtime.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            with_adapters=use_patches
        )

        if use_patches:
            self.patch_usage_count += 1

        return {
            "text": result["text"],
            "elapsed_seconds": result["elapsed_seconds"],
            "used_patches": use_patches,
            "adapter_paths": self.runtime.loaded_adapters if use_patches else [],
            # Whitespace-split word counts are only a rough proxy for
            # tokenizer token counts.
            "total_tokens": len(prompt.split()) + len(result["text"].split())
        }

    def benchmark(self,
                  queries: List[str],
                  with_patches: bool = True,
                  max_tokens: int = 256) -> Dict[str, Any]:
        """
        Run a benchmark over a list of queries.

        Args:
            queries: List of query strings
            with_patches: Whether to use patches
            max_tokens: Maximum tokens to generate

        Returns:
            Benchmark results
        """
        results = []
        total_time = 0.0

        for query in queries:
            start_time = time.time()

            response = self.runtime.generate(
                prompt=query,
                max_tokens=max_tokens,
                with_adapters=with_patches
            )

            elapsed = time.time() - start_time
            total_time += elapsed

            results.append({
                "query": query,
                "elapsed_seconds": elapsed,
                # Rough token count via whitespace splitting.
                "tokens": len(response["text"].split())
            })

        avg_time = total_time / len(queries) if queries else 0.0

        return {
            "num_queries": len(queries),
            "total_time": total_time,
            "average_time": avg_time,
            "with_patches": with_patches,
            "results": results
        }

    def compare_outputs(self,
                        query: str,
                        max_tokens: int = 256) -> Dict[str, Any]:
        """
        Compare outputs from the base model and the patched model.

        Args:
            query: Query string
            max_tokens: Maximum tokens to generate

        Returns:
            Comparison results
        """
        base_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=False
        )

        patched_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=True
        )

        return {
            "query": query,
            "base_output": base_result["text"],
            "patched_output": patched_result["text"],
            "base_time": base_result["elapsed_seconds"],
            "patched_time": patched_result["elapsed_seconds"]
        }
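

# Minimal usage sketch. The model path below and the example prompts are
# assumptions for illustration, not files or data shipped with this project.
if __name__ == "__main__":
    router = ContinuumRouter(
        model_path="models/base/model.Q4_K_M.gguf",  # hypothetical path
        registry_dir="models/registry",
        n_gpu_layers=-1
    )

    # Load the most recent patch set, then let the difficulty gate decide
    # per query whether the patches are needed.
    router.load_latest_patches()
    result = router.generate("Explain what a LoRA patch changes in a model.")
    print(f"used_patches={result['used_patches']}: {result['text']}")

    # Side-by-side comparison of base vs. patched output for one query.
    comparison = router.compare_outputs("What changed in the latest patch?")
    print("base:   ", comparison["base_output"])
    print("patched:", comparison["patched_output"])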