deasdutta commited on
Commit
8838cce
·
verified ·
1 Parent(s): 1932054

Upload runtime/difficulty_gate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. runtime//difficulty_gate.py +319 -0
runtime//difficulty_gate.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Difficulty Gate for ContinuumAgent Project
4
+ Smart routing system to determine whether to use patches based on query complexity
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Dict, Any, List, Optional, Tuple
10
+ import numpy as np
11
+ from llama_cpp import Llama
12
+
13
class DifficultyGate:
    """
    Smart routing system to determine whether to use patches based on query complexity.

    Combines cheap lexical heuristics with an optional small LLM classifier.
    Decisions are cached on disk (JSON) so repeated queries are answered
    without recomputation. Every failure path (missing model, broken cache,
    inference error) degrades gracefully to the heuristic decision.
    """

    def __init__(self,
                 model_path: str,
                 gate_threshold: float = 0.7,
                 cache_dir: str = "models/gates",
                 n_gpu_layers: int = 0):
        """
        Initialize the difficulty gate.

        Args:
            model_path: Path to GGUF model file
            gate_threshold: Threshold for routing to patched model (0.0-1.0)
            cache_dir: Directory for caching gate decisions
            n_gpu_layers: Number of layers to offload to GPU
        """
        self.model_path = model_path
        self.gate_threshold = gate_threshold
        self.cache_dir = cache_dir
        self.n_gpu_layers = n_gpu_layers

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        # Cache file path
        self.cache_file = os.path.join(cache_dir, "gate_cache.json")

        # Load cache of previous decisions (empty mapping on first run)
        self.decision_cache = self._load_cache()

        # Initialize gate model (small context for efficiency)
        self._init_gate_model()

    def _init_gate_model(self) -> None:
        """Initialize small gate model; on failure the heuristic path still works."""
        try:
            print(f"Loading gate model from {self.model_path}...")
            self.gate_model = Llama(
                model_path=self.model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=512  # Small context for efficiency
            )
        except Exception as e:
            # Keep the gate usable without a model: callers check gate_model
            # for truthiness and fall back to heuristics.
            print(f"Error loading gate model: {e}")
            self.gate_model = None

    def _load_cache(self) -> Dict[str, Any]:
        """
        Load decision cache from file.

        Returns:
            Cache dictionary with a "queries" mapping of query-hash -> decision.
        """
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r") as f:
                    cache = json.load(f)
                print(f"Loaded {len(cache.get('queries', []))} cached gate decisions")
                return cache
            except Exception as e:
                # Corrupt or unreadable cache: start fresh rather than crash.
                print(f"Error loading cache: {e}")

        # Return empty cache
        return {"queries": {}}

    def _save_cache(self) -> None:
        """Save decision cache to file (best-effort; failures are logged)."""
        try:
            with open(self.cache_file, "w") as f:
                json.dump(self.decision_cache, f, indent=2)
        except Exception as e:
            print(f"Error saving cache: {e}")

    def _query_hash(self, query: str) -> str:
        """
        Create simple hash for query caching.

        Args:
            query: Query string

        Returns:
            Hex digest keyed on the normalized (stripped, lowercased) query.
        """
        # MD5 is used purely as a cache key, not for anything security-sensitive.
        import hashlib
        return hashlib.md5(query.strip().lower().encode()).hexdigest()

    def _heuristic_features(self, query: str) -> Dict[str, float]:
        """
        Extract heuristic features from query.

        Args:
            query: Query string

        Returns:
            Dictionary of feature values, each normalized to the 0.0-1.0 range.
        """
        # Lowercase query for consistent processing
        query_lower = query.lower()

        # Feature 1: Query length
        length = len(query)
        norm_length = min(1.0, length / 200.0)  # Normalize to 0-1 (capped at 200 chars)

        # Feature 2: Presence of factual question indicators
        factual_indicators = [
            "what is", "when did", "where is", "who is",
            "which", "how many", "list the", "tell me about",
            "explain", "define"
        ]
        has_factual = any(indicator in query_lower for indicator in factual_indicators)

        # Feature 3: Presence of time indicators (recency)
        time_indicators = [
            "recent", "latest", "current", "today", "now",
            "this week", "this month", "this year",
            "2023", "2024", "2025"  # Add current years
        ]
        has_time = any(indicator in query_lower for indicator in time_indicators)

        # Feature 4: Entity recognition (simplified)
        # Check for capitalized terms that may indicate named entities.
        # w[0:1] (not w[0]) is safe for empty tokens.
        words = query.split()
        capitalized_words = [w for w in words if w[0:1].isupper()]
        entity_ratio = len(capitalized_words) / max(1, len(words))

        # Feature 5: Question complexity based on interrogative words
        complex_indicators = [
            "why", "how does", "explain", "compare", "contrast",
            "what if", "analyze", "evaluate", "synthesize"
        ]
        complexity_score = sum(indicator in query_lower for indicator in complex_indicators) / 3.0
        complexity_score = min(1.0, complexity_score)

        # Return features
        return {
            "length": norm_length,
            "has_factual": float(has_factual),
            "has_time": float(has_time),
            "entity_ratio": entity_ratio,
            "complexity": complexity_score
        }

    def _heuristic_decision(self, features: Dict[str, float]) -> Tuple[bool, float]:
        """
        Make decision based on heuristic features.

        Args:
            features: Feature dictionary from _heuristic_features()

        Returns:
            Tuple of (needs_patches, confidence)
        """
        # Weights for different features
        weights = {
            "length": 0.1,
            "has_factual": 0.3,
            "has_time": 0.4,  # Highest weight for time indicators
            "entity_ratio": 0.1,
            "complexity": -0.1  # Negative weight - complex reasoning queries may not need patches
        }

        # Calculate weighted score
        score = sum(features[f] * weights[f] for f in features)

        # Normalize to 0-1 range
        score = max(0.0, min(1.0, score))

        # Decision based on threshold
        needs_patches = score >= self.gate_threshold

        return needs_patches, score

    def _model_decision(self, query: str) -> Tuple[bool, float]:
        """
        Ask the model to decide if the query needs up-to-date knowledge.

        Falls back to the heuristic decision when the model is unavailable
        or the completion call fails.

        Args:
            query: Query string

        Returns:
            Tuple of (needs_patches, confidence)
        """
        if not self.gate_model:
            # Fall back to heuristic if model not available
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        # Prompt for model
        prompt = f"""<s>[INST] Please analyze this question and determine if it requires the most up-to-date knowledge to answer correctly.
Respond with only a single word: 'YES' if up-to-date knowledge is needed, or 'NO' if it can be answered with general knowledge.

Question: "{query}"

Requires up-to-date knowledge? [/INST]"""

        try:
            # Generate completion
            completion = self.gate_model.create_completion(
                prompt=prompt,
                max_tokens=5,
                temperature=0.1,  # Low temperature for consistent results
                stop=["</s>", "\n"]
            )
        except Exception as e:
            # Inference failure: degrade gracefully to the heuristic path
            # instead of crashing the whole routing decision.
            print(f"Error running gate model: {e}")
            features = self._heuristic_features(query)
            return self._heuristic_decision(features)

        # Extract response
        response_text = completion.get("choices", [{}])[0].get("text", "").strip().upper()

        # Calculate confidence from logprobs if available
        confidence = 0.7  # Default confidence (logprob-based scoring not implemented yet)

        # Decision based on response
        needs_patches = "YES" in response_text

        return needs_patches, confidence

    def should_use_patches(self, query: str, use_model: bool = True) -> Dict[str, Any]:
        """
        Determine if the query requires up-to-date knowledge patches.

        Args:
            query: Query string
            use_model: Whether to use model for decision (vs pure heuristics)

        Returns:
            Decision dictionary with keys:
            - needs_patches: Boolean decision
            - confidence: Confidence score (0.0-1.0)
            - method: Decision method used ("model" or "heuristic")
            - features: Feature values (always computed)
            - from_cache: Whether this decision was served from the cache
        """
        # Check cache first
        query_hash = self._query_hash(query)
        if query_hash in self.decision_cache.get("queries", {}):
            # Return a COPY so the transient "from_cache" flag is never
            # written back into the stored entry (and thus never persisted
            # to disk by a later _save_cache()).
            cached = dict(self.decision_cache["queries"][query_hash])
            cached["from_cache"] = True
            return cached

        # Extract features
        features = self._heuristic_features(query)

        # Make decision
        if use_model and self.gate_model:
            needs_patches, confidence = self._model_decision(query)
            method = "model"
        else:
            needs_patches, confidence = self._heuristic_decision(features)
            method = "heuristic"

        # Create decision
        decision = {
            "needs_patches": needs_patches,
            "confidence": confidence,
            "method": method,
            "features": features,
            "from_cache": False
        }

        # Cache decision
        self.decision_cache.setdefault("queries", {})[query_hash] = decision
        self._save_cache()

        return decision
279
+
280
def main():
    """Smoke-test the difficulty gate against a handful of sample queries."""
    # Find model path
    model_dir = "models/slow"
    try:
        model_files = [f for f in os.listdir(model_dir) if f.endswith(".gguf")]
    except FileNotFoundError:
        # A missing directory previously crashed with an unhandled exception;
        # treat it the same way as an empty directory.
        print(f"Model directory not found: {model_dir}")
        return

    if not model_files:
        print(f"No GGUF models found in {model_dir}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Initialize gate
    gate = DifficultyGate(model_path=model_path)

    # Test queries
    test_queries = [
        "What is the capital of France?",
        "Who is the current president of the United States?",
        "Explain the theory of relativity",
        "What are the latest developments in the conflict in Ukraine?",
        "Who won the most recent Super Bowl?",
        "How do I write a for loop in Python?"
    ]

    for query in test_queries:
        # Test heuristic decision
        decision = gate.should_use_patches(query, use_model=False)
        print(f"\nQuery: {query}")
        print(f"Heuristic Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")
        print(f"Features: {decision['features']}")

        # Test model decision if model is available
        if gate.gate_model:
            decision = gate.should_use_patches(query, use_model=True)
            print(f"Model Decision: {decision['needs_patches']} (Confidence: {decision['confidence']:.2f})")


if __name__ == "__main__":
    main()