|
|
|
|
|
"""
|
|
|
Difficulty Gate for ContinuumAgent Project
|
|
|
Smart routing system to determine whether to use patches based on query complexity
|
|
|
"""
|
|
|
|
|
|
import hashlib
import json
import os
from typing import Dict, Any, List, Optional, Tuple

import numpy as np
from llama_cpp import Llama
|
|
|
|
|
|
class DifficultyGate:
    """
    Smart routing system to determine whether to use patches based on query complexity

    Uses a simple heuristic approach for initial implementation, can be replaced with a learned classifier

    Decisions are cached on disk (JSON) so repeated queries are answered without
    re-running the heuristics or the gate model.
    """

    def __init__(self,
                 model_path: str,
                 gate_threshold: float = 0.7,
                 cache_dir: str = "models/gates",
                 n_gpu_layers: int = 0):
        """
        Initialize the difficulty gate

        Args:
            model_path: Path to GGUF model file
            gate_threshold: Threshold for routing to patched model (0.0-1.0)
            cache_dir: Directory for caching gate decisions
            n_gpu_layers: Number of layers to offload to GPU
        """
        self.model_path = model_path
        self.gate_threshold = gate_threshold
        self.cache_dir = cache_dir
        self.n_gpu_layers = n_gpu_layers

        os.makedirs(cache_dir, exist_ok=True)

        # Persisted decision cache lives alongside the gate artifacts.
        self.cache_file = os.path.join(cache_dir, "gate_cache.json")
        self.decision_cache = self._load_cache()

        # If model loading fails, gate_model is None and heuristics are used.
        self._init_gate_model()

    def _init_gate_model(self) -> None:
        """Initialize small gate model; on any failure fall back to heuristics-only mode."""
        try:
            print(f"Loading gate model from {self.model_path}...")
            self.gate_model = Llama(
                model_path=self.model_path,
                n_gpu_layers=self.n_gpu_layers,
                # A short context is enough for the single YES/NO gate prompt.
                n_ctx=512
            )
        except Exception as e:
            # A missing or corrupt model is not fatal: heuristic routing still works.
            print(f"Error loading gate model: {e}")
            self.gate_model = None

    def _load_cache(self) -> Dict[str, Any]:
        """
        Load decision cache from file

        Returns:
            Cache dictionary (always contains a "queries" mapping)
        """
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r") as f:
                    cache = json.load(f)
                print(f"Loaded {len(cache.get('queries', []))} cached gate decisions")
                return cache
            except Exception as e:
                # Corrupt cache is discarded rather than crashing startup.
                print(f"Error loading cache: {e}")

        return {"queries": {}}

    def _save_cache(self) -> None:
        """Save decision cache to file (best-effort; errors are logged, not raised)."""
        try:
            with open(self.cache_file, "w") as f:
                json.dump(self.decision_cache, f, indent=2)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def _query_hash(self, query: str) -> str:
        """
        Create simple hash for query caching

        The query is stripped and lowercased first so trivially different
        spellings of the same question share a cache entry. MD5 is used only
        as a cache key, not for security.

        Args:
            query: Query string

        Returns:
            Query hash (hex digest)
        """
        return hashlib.md5(query.strip().lower().encode()).hexdigest()

    def _heuristic_features(self, query: str) -> Dict[str, float]:
        """
        Extract heuristic features from query

        Args:
            query: Query string

        Returns:
            Dictionary of feature values, each normalized to roughly [0, 1]
        """
        query_lower = query.lower()

        # Longer queries tend to be more specific; cap at 200 chars.
        length = len(query)
        norm_length = min(1.0, length / 200.0)

        # Phrases that signal a factual-lookup style question.
        factual_indicators = [
            "what is", "when did", "where is", "who is",
            "which", "how many", "list the", "tell me about",
            "explain", "define"
        ]
        has_factual = any(indicator in query_lower for indicator in factual_indicators)

        # Phrases/years that signal the answer depends on recent events.
        time_indicators = [
            "recent", "latest", "current", "today", "now",
            "this week", "this month", "this year",
            "2023", "2024", "2025"
        ]
        has_time = any(indicator in query_lower for indicator in time_indicators)

        # Capitalized words as a cheap proxy for named entities.
        # NOTE(review): the sentence-initial word inflates this slightly.
        words = query.split()
        capitalized_words = [w for w in words if w[0:1].isupper()]
        entity_ratio = len(capitalized_words) / max(1, len(words))

        # Phrases that signal analytical/reasoning questions.
        complex_indicators = [
            "why", "how does", "explain", "compare", "contrast",
            "what if", "analyze", "evaluate", "synthesize"
        ]
        complexity_score = sum(indicator in query_lower for indicator in complex_indicators) / 3.0
        complexity_score = min(1.0, complexity_score)

        return {
            "length": norm_length,
            "has_factual": float(has_factual),
            "has_time": float(has_time),
            "entity_ratio": entity_ratio,
            "complexity": complexity_score
        }

    def _heuristic_decision(self, features: Dict[str, float]) -> Tuple[bool, float]:
        """
        Make decision based on heuristic features

        Args:
            features: Feature dictionary from _heuristic_features

        Returns:
            Tuple of (needs_patches, confidence)
        """
        # Time-sensitivity is the strongest patch signal; complexity is
        # weighted negatively — presumably analytical questions lean on
        # reasoning rather than fresh facts.
        weights = {
            "length": 0.1,
            "has_factual": 0.3,
            "has_time": 0.4,
            "entity_ratio": 0.1,
            "complexity": -0.1
        }

        score = sum(features[f] * weights[f] for f in features)

        # Clamp to [0, 1] so the score is usable as a confidence.
        score = max(0.0, min(1.0, score))

        needs_patches = score >= self.gate_threshold

        return needs_patches, score

    def _model_decision(self, query: str) -> Tuple[bool, float]:
        """
        Ask the model to decide if the query needs up-to-date knowledge

        Falls back to the heuristic decision when the model is unavailable
        or inference fails at runtime.

        Args:
            query: Query string

        Returns:
            Tuple of (needs_patches, confidence)
        """
        if not self.gate_model:
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        prompt = f"""<s>[INST] Please analyze this question and determine if it requires the most up-to-date knowledge to answer correctly.
Respond with only a single word: 'YES' if up-to-date knowledge is needed, or 'NO' if it can be answered with general knowledge.

Question: "{query}"

Requires up-to-date knowledge? [/INST]"""

        try:
            completion = self.gate_model.create_completion(
                prompt=prompt,
                max_tokens=5,
                temperature=0.1,
                stop=["</s>", "\n"]
            )
            response_text = completion.get("choices", [{}])[0].get("text", "").strip().upper()
        except Exception as e:
            # Inference failed at runtime; degrade gracefully to heuristics.
            print(f"Error querying gate model: {e}")
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        # Fixed confidence: the model emits a label, not a calibrated score.
        confidence = 0.7

        needs_patches = "YES" in response_text

        return needs_patches, confidence

    def should_use_patches(self, query: str, use_model: bool = True) -> Dict[str, Any]:
        """
        Determine if the query requires up-to-date knowledge patches

        Args:
            query: Query string
            use_model: Whether to use model for decision (vs pure heuristics)

        Returns:
            Decision dictionary with keys:
            - needs_patches: Boolean decision
            - confidence: Confidence score (0.0-1.0)
            - method: Decision method used
            - features: Feature values if heuristic method used
            - from_cache: Whether the decision came from the cache
        """
        query_hash = self._query_hash(query)
        if query_hash in self.decision_cache.get("queries", {}):
            # Return a copy so the persisted cache entry keeps from_cache=False;
            # mutating the cached dict would leak from_cache=True into the file.
            cached = dict(self.decision_cache["queries"][query_hash])
            cached["from_cache"] = True
            return cached

        # Features are always computed so they can be logged/cached even
        # when the model makes the final call.
        features = self._heuristic_features(query)

        if use_model and self.gate_model:
            needs_patches, confidence = self._model_decision(query)
            method = "model"
        else:
            needs_patches, confidence = self._heuristic_decision(features)
            method = "heuristic"

        decision = {
            "needs_patches": needs_patches,
            "confidence": confidence,
            "method": method,
            "features": features,
            "from_cache": False
        }

        self.decision_cache.setdefault("queries", {})[query_hash] = decision
        self._save_cache()

        return decision
|
|
|
|
|
|
def main():
    """Test difficulty gate

    Smoke test: loads the first GGUF model found in models/slow and runs a
    set of sample queries through both the heuristic and (if available)
    model-based gates, printing each decision.
    """
    model_dir = "models/slow"

    # Guard against a missing directory — os.listdir would raise
    # FileNotFoundError on a fresh checkout with no models downloaded.
    if not os.path.isdir(model_dir):
        print(f"No GGUF models found in {model_dir}")
        return

    model_files = [f for f in os.listdir(model_dir) if f.endswith(".gguf")]

    if not model_files:
        print(f"No GGUF models found in {model_dir}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    gate = DifficultyGate(model_path=model_path)

    # Mix of static-knowledge and time-sensitive questions.
    test_queries = [
        "What is the capital of France?",
        "Who is the current president of the United States?",
        "Explain the theory of relativity",
        "What are the latest developments in the conflict in Ukraine?",
        "Who won the most recent Super Bowl?",
        "How do I write a for loop in Python?"
    ]

    for query in test_queries:
        decision = gate.should_use_patches(query, use_model=False)
        print(f"\nQuery: {query}")
        print(f"Heuristic Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")
        print(f"Features: {decision['features']}")

        # Only attempt the model path when the gate model actually loaded.
        if gate.gate_model:
            decision = gate.should_use_patches(query, use_model=True)
            print(f"Model Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")
|
|
|
|
|
|
# Allow running this module directly as a smoke test of the gate.
if __name__ == "__main__":
    main()