#!/usr/bin/env python
"""
Continuum Router for ContinuumAgent Project.

Routes requests between the base model and the patched (LoRA-adapted) model
based on query complexity.
"""

import os
import time
from typing import Dict, List, Any, Optional, Tuple, Union

from runtime.gguf_lora_runtime import GGUFLoraRuntime
from runtime.difficulty_gate import DifficultyGate
from runtime.lora_mux import LoraMux


class ContinuumRouter:
    """
    Routes requests between base model and patched model.

    Wires together three components:
      * ``GGUFLoraRuntime`` -- performs generation, with or without the
        loaded adapters, and tracks which adapters are loaded.
      * ``DifficultyGate`` -- decides per query whether patches are needed.
      * ``LoraMux`` -- lists the patches in the registry and finds the latest.
    """

    def __init__(self, model_path: str, registry_dir: str = "models/registry",
                 n_gpu_layers: int = -1):
        """
        Initialize the continuum router.

        Args:
            model_path: Path to GGUF model file
            registry_dir: Path to LoRA registry directory
            n_gpu_layers: Number of layers to offload to GPU (-1 for all)
        """
        self.model_path = model_path
        self.registry_dir = registry_dir
        self.n_gpu_layers = n_gpu_layers

        # Extract model details from path
        self.model_name = os.path.basename(model_path)

        # Initialize components
        print("Initializing GGUF runtime...")
        self.runtime = GGUFLoraRuntime(
            model_path=model_path,
            registry_dir=registry_dir,
            n_gpu_layers=n_gpu_layers
        )

        print("Initializing difficulty gate...")
        self.gate = DifficultyGate(
            model_path=model_path,
            n_gpu_layers=0  # Use CPU for gate model (lightweight)
        )

        print("Initializing LoRA mux...")
        self.lora_mux = LoraMux(registry_dir=registry_dir)

        # Statistics. Only `generate` updates these; `benchmark` and
        # `compare_outputs` call the runtime directly and are not counted.
        self.request_count = 0
        self.patch_usage_count = 0

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with keys ``name`` (model file basename),
            ``quantization`` (tag parsed from the file name, or "unknown"),
            ``patches`` (result of ``list_patches``) and ``using_gpu``
            (True unless ``n_gpu_layers`` is 0).
        """
        # Extract quantization format from model name, e.g.
        # "model.Q4_K_M.gguf" -> "4_K_M".
        # NOTE(review): the split on ".Q" drops the leading "Q" from the
        # tag; kept as-is because callers may depend on this format.
        quant_format = "unknown"
        if ".Q" in self.model_name:
            quant_format = self.model_name.split(".Q")[1].split(".")[0]

        # Get available patches
        patches = self.list_patches()

        # Create model info
        return {
            "name": self.model_name,
            "quantization": quant_format,
            "patches": patches,
            "using_gpu": self.n_gpu_layers != 0
        }

    def list_patches(self) -> List[Dict[str, Any]]:
        """
        List available patches.

        Returns:
            List of patch info dictionaries, as reported by the LoRA mux
        """
        return self.lora_mux.get_available_patches()

    def get_active_patches(self) -> List[str]:
        """
        Get currently active patches.

        Returns:
            A new list of the active patch paths. It is a copy, so callers
            cannot accidentally mutate the runtime's adapter state.
        """
        return list(self.runtime.loaded_adapters)

    def load_patches(self, date_str: Optional[str] = None) -> List[str]:
        """
        Load patches for a specific date.

        Args:
            date_str: Date string in YYYYMMDD format (defaults to today)

        Returns:
            List of loaded patch paths
        """
        return self.runtime.load_adapters(date_str)

    def load_latest_patches(self) -> List[str]:
        """
        Load the patches for the date of the most recent patch.

        Returns:
            List of loaded patch paths (empty if no patches are available)
        """
        # Get latest patch
        latest_patch = self.lora_mux.get_latest_patch()

        if not latest_patch:
            print("No patches available")
            return []

        # Extract the date directory from the patch path. Assumes the path is
        # relative to the registry and shaped like "<date>/<...>" with "/"
        # separators -- TODO confirm against LoraMux. Without a "/" the date
        # is None and the runtime falls back to its default date.
        path = latest_patch.get("path", "")
        date_str = path.split("/")[0] if "/" in path else None

        # Load patches
        return self.load_patches(date_str)

    def should_use_patches(self, query: str,
                           force_patches: Optional[bool] = None) -> bool:
        """
        Determine if patches should be used for the query.

        Args:
            query: Query string
            force_patches: Force using (True) or not using (False) patches;
                None lets the difficulty gate decide

        Returns:
            Boolean decision
        """
        # If force_patches is specified, use that decision
        if force_patches is not None:
            return force_patches

        # Otherwise, use the gate to decide
        decision = self.gate.should_use_patches(query)
        return decision["needs_patches"]

    def generate(self, prompt: str, system_prompt: Optional[str] = None,
                 max_tokens: int = 256, temperature: float = 0.7,
                 top_p: float = 0.95, auto_route: bool = True,
                 force_patches: Optional[bool] = None) -> Dict[str, Any]:
        """
        Generate a response, routing to the base or patched model.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            auto_route: Whether to use automatic routing via the gate
            force_patches: Force using or not using patches. When
                ``auto_route`` is False and this is None, patches are used.

        Returns:
            Dictionary with ``text``, ``elapsed_seconds`` (from the runtime),
            ``used_patches``, ``adapter_paths`` (a copy of the loaded adapter
            paths when patches were used, else []) and ``total_tokens`` (a
            whitespace-split word count of prompt + output, not a real
            tokenizer count).
        """
        # Update request count
        self.request_count += 1

        # Determine if patches should be used
        if not auto_route:
            # Use patches based on force_patches (default to True if not specified)
            use_patches = force_patches if force_patches is not None else True
        else:
            # Gate decides, unless force_patches overrides it
            use_patches = self.should_use_patches(prompt, force_patches)

        # Generate response (timing is reported by the runtime itself)
        result = self.runtime.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            with_adapters=use_patches
        )

        # Update statistics
        if use_patches:
            self.patch_usage_count += 1

        # Format response
        return {
            "text": result["text"],
            "elapsed_seconds": result["elapsed_seconds"],
            "used_patches": use_patches,
            # Copy so the caller cannot mutate the runtime's adapter list
            "adapter_paths": list(self.runtime.loaded_adapters) if use_patches else [],
            "total_tokens": len(prompt.split()) + len(result["text"].split())  # Approximate
        }

    def benchmark(self, queries: List[str], with_patches: bool = True,
                  max_tokens: int = 256) -> Dict[str, Any]:
        """
        Run a sequential benchmark over a list of queries.

        Does not update ``request_count`` / ``patch_usage_count``.

        Args:
            queries: List of query strings
            with_patches: Whether to use patches
            max_tokens: Maximum tokens to generate

        Returns:
            Dictionary with ``num_queries``, ``total_time``, ``average_time``
            (0 for an empty query list), ``with_patches`` and per-query
            ``results`` (query, elapsed seconds, whitespace-split word count).
        """
        results = []
        total_time = 0.0

        for query in queries:
            # perf_counter is monotonic, so durations are immune to
            # wall-clock adjustments (unlike time.time()).
            start_time = time.perf_counter()
            response = self.runtime.generate(
                prompt=query,
                max_tokens=max_tokens,
                with_adapters=with_patches
            )
            elapsed = time.perf_counter() - start_time
            total_time += elapsed

            # Add to results
            results.append({
                "query": query,
                "elapsed_seconds": elapsed,
                "tokens": len(response["text"].split())
            })

        # Calculate statistics (guard against division by zero)
        avg_time = total_time / len(queries) if queries else 0.0

        return {
            "num_queries": len(queries),
            "total_time": total_time,
            "average_time": avg_time,
            "with_patches": with_patches,
            "results": results
        }

    def compare_outputs(self, query: str, max_tokens: int = 256) -> Dict[str, Any]:
        """
        Compare outputs from base model and patched model for one query.

        Runs the base model first, then the patched model. Does not update
        ``request_count`` / ``patch_usage_count``.

        Args:
            query: Query string
            max_tokens: Maximum tokens to generate

        Returns:
            Dictionary with the query, both outputs and both runtime-reported
            elapsed times
        """
        # Generate with base model
        base_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=False
        )

        # Generate with patched model
        patched_result = self.runtime.generate(
            prompt=query,
            max_tokens=max_tokens,
            with_adapters=True
        )

        return {
            "query": query,
            "base_output": base_result["text"],
            "patched_output": patched_result["text"],
            "base_time": base_result["elapsed_seconds"],
            "patched_time": patched_result["elapsed_seconds"]
        }