Upload runtime/difficulty_gate.py with huggingface_hub
Browse files - runtime/difficulty_gate.py +319 -0
runtime/difficulty_gate.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Difficulty Gate for ContinuumAgent Project
|
| 4 |
+
Smart routing system to determine whether to use patches based on query complexity
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 10 |
+
import numpy as np
|
| 11 |
+
from llama_cpp import Llama
|
| 12 |
+
|
| 13 |
+
class DifficultyGate:
    """
    Smart routing system to determine whether to use patches based on query
    complexity.

    A cheap lexical-heuristic scorer is used by default; optionally a small
    local GGUF model (via llama-cpp) is consulted instead. Decisions are
    cached on disk, keyed by normalized query text AND decision method, so
    repeated queries are answered without recomputation.
    """

    def __init__(self,
                 model_path: str,
                 gate_threshold: float = 0.7,
                 cache_dir: str = "models/gates",
                 n_gpu_layers: int = 0):
        """
        Initialize the difficulty gate.

        Args:
            model_path: Path to GGUF model file
            gate_threshold: Threshold for routing to patched model (0.0-1.0)
            cache_dir: Directory for caching gate decisions
            n_gpu_layers: Number of layers to offload to GPU
        """
        self.model_path = model_path
        self.gate_threshold = gate_threshold
        self.cache_dir = cache_dir
        self.n_gpu_layers = n_gpu_layers

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        # Cache file path
        self.cache_file = os.path.join(cache_dir, "gate_cache.json")

        # Load previously persisted decisions (empty cache on any failure)
        self.decision_cache = self._load_cache()

        # Initialize the small gate model (sets self.gate_model, possibly None)
        self._init_gate_model()

    def _init_gate_model(self) -> None:
        """Load the small gate model; on failure leave self.gate_model as
        None so callers fall back to pure heuristics."""
        try:
            print(f"Loading gate model from {self.model_path}...")
            self.gate_model = Llama(
                model_path=self.model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=512  # Small context for efficiency
            )
        except Exception as e:
            print(f"Error loading gate model: {e}")
            self.gate_model = None

    def _load_cache(self) -> Dict[str, Any]:
        """
        Load the decision cache from disk.

        Returns:
            Cache dictionary of the form {"queries": {key: decision}}.
            An empty cache is returned if the file is missing or unreadable.
        """
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r") as f:
                    cache = json.load(f)
                print(f"Loaded {len(cache.get('queries', []))} cached gate decisions")
                return cache
            except Exception as e:
                print(f"Error loading cache: {e}")

        # Return empty cache
        return {"queries": {}}

    def _save_cache(self) -> None:
        """Persist the decision cache to disk (best-effort; errors are logged)."""
        try:
            with open(self.cache_file, "w") as f:
                json.dump(self.decision_cache, f, indent=2)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def _query_hash(self, query: str) -> str:
        """
        Create a stable hash for query caching.

        Queries are normalized (stripped, lowercased) so trivially different
        spellings share one cache entry. MD5 is used purely as a cache key,
        not for any security purpose.

        Args:
            query: Query string

        Returns:
            Hex digest string
        """
        import hashlib
        return hashlib.md5(query.strip().lower().encode()).hexdigest()

    def _heuristic_features(self, query: str) -> Dict[str, float]:
        """
        Extract cheap lexical features from a query.

        Args:
            query: Query string

        Returns:
            Dictionary of feature values, each in [0, 1]
        """
        # Lowercase query for consistent substring matching
        query_lower = query.lower()

        # Feature 1: Query length, normalized to 0-1 (capped at 200 chars)
        norm_length = min(1.0, len(query) / 200.0)

        # Feature 2: Presence of factual question indicators
        factual_indicators = [
            "what is", "when did", "where is", "who is",
            "which", "how many", "list the", "tell me about",
            "explain", "define"
        ]
        has_factual = any(indicator in query_lower for indicator in factual_indicators)

        # Feature 3: Presence of time indicators (recency)
        time_indicators = [
            "recent", "latest", "current", "today", "now",
            "this week", "this month", "this year",
            "2023", "2024", "2025"  # Add current years
        ]
        has_time = any(indicator in query_lower for indicator in time_indicators)

        # Feature 4: Simplified entity recognition — ratio of capitalized
        # words (may indicate named entities; note the sentence-initial word
        # inflates this slightly)
        words = query.split()
        capitalized_words = [w for w in words if w[0:1].isupper()]
        entity_ratio = len(capitalized_words) / max(1, len(words))

        # Feature 5: Question complexity based on interrogative words;
        # three or more indicators saturate the score at 1.0
        complex_indicators = [
            "why", "how does", "explain", "compare", "contrast",
            "what if", "analyze", "evaluate", "synthesize"
        ]
        complexity_score = min(
            1.0,
            sum(indicator in query_lower for indicator in complex_indicators) / 3.0
        )

        return {
            "length": norm_length,
            "has_factual": float(has_factual),
            "has_time": float(has_time),
            "entity_ratio": entity_ratio,
            "complexity": complexity_score
        }

    def _heuristic_decision(self, features: Dict[str, float]) -> Tuple[bool, float]:
        """
        Make a routing decision from heuristic features.

        Args:
            features: Feature dictionary from _heuristic_features

        Returns:
            Tuple of (needs_patches, confidence)
        """
        # Weights for different features
        weights = {
            "length": 0.1,
            "has_factual": 0.3,
            "has_time": 0.4,  # Highest weight for time indicators
            "entity_ratio": 0.1,
            "complexity": -0.1  # Negative weight - complex reasoning queries may not need patches
        }

        # Calculate weighted score
        score = sum(features[f] * weights[f] for f in features)

        # Normalize to 0-1 range
        score = max(0.0, min(1.0, score))

        # Decision based on threshold
        needs_patches = score >= self.gate_threshold

        return needs_patches, score

    def _model_decision(self, query: str) -> Tuple[bool, float]:
        """
        Ask the gate model whether the query needs up-to-date knowledge.

        Falls back to the heuristic decision when the model is unavailable.

        Args:
            query: Query string

        Returns:
            Tuple of (needs_patches, confidence)
        """
        if not self.gate_model:
            # Fall back to heuristic if model not available
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        # Mistral-style instruction prompt asking for a single YES/NO token
        prompt = f"""<s>[INST] Please analyze this question and determine if it requires the most up-to-date knowledge to answer correctly.
Respond with only a single word: 'YES' if up-to-date knowledge is needed, or 'NO' if it can be answered with general knowledge.

Question: "{query}"

Requires up-to-date knowledge? [/INST]"""

        # Generate completion
        completion = self.gate_model.create_completion(
            prompt=prompt,
            max_tokens=5,
            temperature=0.1,  # Low temperature for consistent results
            stop=["</s>", "\n"]
        )

        # Extract response defensively: "choices" may be missing or an
        # empty list, which would otherwise raise IndexError
        choices = completion.get("choices") or [{}]
        response_text = choices[0].get("text", "").strip().upper()

        # Fixed confidence; a logprob-derived confidence could replace this
        confidence = 0.7

        # Decision based on response
        needs_patches = "YES" in response_text

        return needs_patches, confidence

    def should_use_patches(self, query: str, use_model: bool = True) -> Dict[str, Any]:
        """
        Determine if the query requires up-to-date knowledge patches.

        Args:
            query: Query string
            use_model: Whether to use the gate model for the decision
                (automatically falls back to heuristics if the model
                failed to load)

        Returns:
            Decision dictionary with keys:
            - needs_patches: Boolean decision
            - confidence: Confidence score (0.0-1.0)
            - method: Decision method used ("model" or "heuristic")
            - features: Heuristic feature values
            - from_cache: Whether the decision was served from the cache
        """
        # Resolve the method up front so cache hits are method-specific:
        # a cached heuristic decision must never satisfy a model request
        # (and vice versa)
        method = "model" if (use_model and self.gate_model) else "heuristic"
        cache_key = f"{self._query_hash(query)}:{method}"

        cached = self.decision_cache.get("queries", {}).get(cache_key)
        if cached is not None:
            cached["from_cache"] = True
            return cached

        # Features are always recorded, even for model decisions
        features = self._heuristic_features(query)

        # Make decision
        if method == "model":
            needs_patches, confidence = self._model_decision(query)
        else:
            needs_patches, confidence = self._heuristic_decision(features)

        # Create decision
        decision = {
            "needs_patches": needs_patches,
            "confidence": confidence,
            "method": method,
            "features": features,
            "from_cache": False
        }

        # Cache and persist the decision
        self.decision_cache.setdefault("queries", {})[cache_key] = decision
        self._save_cache()

        return decision
|
| 279 |
+
|
| 280 |
+
def main():
    """Smoke-test the difficulty gate against a handful of sample queries.

    Looks for the first GGUF model under models/slow, builds a gate, and
    prints both heuristic and (if available) model decisions.
    """
    # Find model path; guard against a missing directory, which would
    # otherwise make os.listdir raise FileNotFoundError
    model_dir = "models/slow"
    if not os.path.isdir(model_dir):
        print(f"No GGUF models found in {model_dir}")
        return

    model_files = [f for f in os.listdir(model_dir) if f.endswith(".gguf")]

    if not model_files:
        print(f"No GGUF models found in {model_dir}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Initialize gate
    gate = DifficultyGate(model_path=model_path)

    # Test queries spanning static knowledge, recency, and reasoning
    test_queries = [
        "What is the capital of France?",
        "Who is the current president of the United States?",
        "Explain the theory of relativity",
        "What are the latest developments in the conflict in Ukraine?",
        "Who won the most recent Super Bowl?",
        "How do I write a for loop in Python?"
    ]

    for query in test_queries:
        # Test heuristic decision
        decision = gate.should_use_patches(query, use_model=False)
        print(f"\nQuery: {query}")
        print(f"Heuristic Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")
        print(f"Features: {decision['features']}")

        # Test model decision if model is available
        if gate.gate_model:
            decision = gate.should_use_patches(query, use_model=True)
            print(f"Model Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")

if __name__ == "__main__":
    main()
|