teofizzy committed on
Commit
868d1b6
·
1 Parent(s): f81b0fc

Revert "feat: Implement a parallel voting ensemble for LLM selection based on Jaccard similarity, replacing the sequential fallback mechanism."

Browse files

This reverts commit f81b0fced262572d64763a9cb34c07bf38d8cc82.

Files changed (1) hide show
  1. src/load/mshauri_demo.py +50 -114
src/load/mshauri_demo.py CHANGED
@@ -4,7 +4,6 @@ import re
4
  import sys
5
  import io
6
  import time
7
- from concurrent.futures import ThreadPoolExecutor, as_completed
8
  from contextlib import redirect_stdout
9
  from typing import Any, List, Optional, Mapping
10
 
@@ -84,109 +83,35 @@ CANDIDATE_MODELS = [
84
  "HuggingFaceH4/zephyr-7b-beta", # Old Reliable
85
  ]
86
 
87
-
88
- # --- VOTING ENSEMBLE ---
89
- class VotingLLM:
90
- """
91
- Calls all available LLMs in parallel and selects the response with
92
- the highest peer-consensus score (Jaccard word similarity).
93
-
94
- Models that fail or time out are silently skipped, providing
95
- built-in fallback behavior without a separate fallback chain.
96
- If ALL models fail, raises ValueError so the agent can handle it.
97
- """
98
- def __init__(self, llms: list, timeout: int = 45):
99
- self.llms = llms
100
- self.timeout = timeout
101
-
102
- def invoke(self, prompt, stop=None):
103
- def call_one(llm):
104
- result = llm.invoke(prompt, stop=stop) if stop else llm.invoke(prompt)
105
- text = result if isinstance(result, str) else result.content
106
- return text.strip() if text else None
107
-
108
- responses = []
109
- with ThreadPoolExecutor(max_workers=len(self.llms)) as executor:
110
- futures = {executor.submit(call_one, llm): llm for llm in self.llms}
111
- try:
112
- for future in as_completed(futures, timeout=self.timeout):
113
- llm_name = futures[future].__class__.__name__
114
- try:
115
- result = future.result()
116
- if result:
117
- responses.append(result)
118
- print(f"Vote received: {llm_name}", flush=True)
119
- except Exception as e:
120
- print(f"Voter {llm_name} failed: {str(e)[:80]}", flush=True)
121
- except TimeoutError:
122
- print("Voting timed out. Using responses collected so far.", flush=True)
123
-
124
- if not responses:
125
- raise ValueError("All LLMs failed to respond during voting.")
126
-
127
- if len(responses) == 1:
128
- return responses[0]
129
-
130
- return self._pick_consensus(responses)
131
-
132
- def _pick_consensus(self, responses: list) -> str:
133
- """Returns the response with the highest average Jaccard similarity to all others.
134
- This is the 'centroid' of the group — the most broadly agreed-upon answer.
135
- """
136
- best_score, best_response = -1.0, responses[0]
137
- for i, r1 in enumerate(responses):
138
- words1 = set(r1.lower().split())
139
- scores = []
140
- for j, r2 in enumerate(responses):
141
- if i == j:
142
- continue
143
- words2 = set(r2.lower().split())
144
- union = len(words1 | words2)
145
- scores.append(len(words1 & words2) / union if union else 0.0)
146
- avg = sum(scores) / len(scores) if scores else 0.0
147
- if avg > best_score:
148
- best_score, best_response = avg, r1
149
- print(f"Consensus winner: score={best_score:.2f}, voters={len(responses)}", flush=True)
150
- return best_response
151
-
152
  def get_robust_llm():
153
- """Builds a voting ensemble LLM from all available providers.
154
-
155
- All available models vote simultaneously on every query. The response
156
- with the highest peer-consensus score (Jaccard word similarity) wins.
157
- Models that fail during a vote are silently skipped — providing
158
- built-in fallback behavior.
159
 
160
- Priority / collection order:
161
  1. Hugging Face (Qwen 72B) - requires HF_TOKEN
162
  2. Groq (Llama 70B) - requires GROQ_API_KEY
163
  3. Gemini (1.5 Flash) - requires GEMINI_API_KEY
164
- 4. Local Ollama (Qwen 7B) - always included
165
-
166
- Returns:
167
- (robust_llm, base_llm)
168
- robust_llm: VotingLLM for the agent brain (or single model if only 1 available)
169
- base_llm: Plain highest-priority model for SQLDatabaseToolkit
170
  """
171
- available_llms = [] # All working models, in priority order
 
172
 
173
- # 1. HuggingFace test candidate models until one responds
174
  hf_token = os.getenv("HF_TOKEN")
175
  if hf_token:
176
- print("HF Token found. Testing candidate models...", flush=True)
177
  for model_id in CANDIDATE_MODELS:
178
- print(f" Trying: {model_id}...", flush=True)
179
  try:
180
- candidate = HuggingFaceChat(repo_id=model_id, hf_token=hf_token, temperature=0.1)
181
- candidate.invoke("Ping")
182
- available_llms.append(candidate)
183
- print(f"HF voter ready: {model_id}", flush=True)
184
  break
185
  except Exception as e:
186
- print(f" Failed {model_id}: {str(e)[:100]}...", flush=True)
187
  time.sleep(0.5)
188
 
189
- # 2. Groq (Llama-3.3-70B)
190
  groq_key = os.getenv("GROQ_API_KEY")
191
  if groq_key:
192
  groq_llm = ChatGroq(
@@ -194,10 +119,14 @@ def get_robust_llm():
194
  temperature=0.1,
195
  api_key=groq_key,
196
  )
197
- available_llms.append(groq_llm)
198
- print("Groq voter ready.", flush=True)
199
-
200
- # 3. Gemini (1.5 Flash)
 
 
 
 
201
  gemini_key = os.getenv("GEMINI_API_KEY")
202
  if gemini_key:
203
  gemini_llm = ChatGoogleGenerativeAI(
@@ -205,29 +134,36 @@ def get_robust_llm():
205
  temperature=0.1,
206
  google_api_key=gemini_key,
207
  )
208
- available_llms.append(gemini_llm)
209
- print("Gemini voter ready.", flush=True)
210
-
211
- # 4. Local Ollama — always included as the guaranteed baseline voter
 
 
 
 
212
  local_llm = ChatOllama(model="qwen2.5:7b", temperature=0)
213
- available_llms.append(local_llm)
214
- print("Ollama voter ready.", flush=True)
215
-
216
- if not available_llms:
217
- return None, None
218
-
219
- # base_llm: highest-priority plain model for SQLDatabaseToolkit
220
- # (VotingLLM is not a LangChain BaseLanguageModel and cannot be passed to the Toolkit)
221
- base_llm = available_llms[0]
222
-
223
- if len(available_llms) == 1:
224
- print("Single voter mode (only Ollama available).", flush=True)
225
- return base_llm, base_llm
 
 
 
 
226
 
227
- print(f"Voting ensemble active: {len(available_llms)} models will collaborate.", flush=True)
228
- return VotingLLM(available_llms), base_llm
229
 
230
- # --- CLASS FOR 'Tool' ---
231
  class SimpleTool:
232
  """A simple wrapper to replace langchain.tools.Tool"""
233
  def __init__(self, name, func, description):
@@ -255,7 +191,7 @@ class PythonREPLTool(SimpleTool):
255
  except Exception as e:
256
  return f"Error executing code: {e}"
257
 
258
- # --- CLASS FOR THE AGENT ---
259
  class SimpleReActAgent:
260
  """A manual ReAct loop that doesn't rely on langchain.agents"""
261
  def __init__(self, llm, tools, verbose=True):
 
4
  import sys
5
  import io
6
  import time
 
7
  from contextlib import redirect_stdout
8
  from typing import Any, List, Optional, Mapping
9
 
 
83
  "HuggingFaceH4/zephyr-7b-beta", # Old Reliable
84
  ]
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def get_robust_llm():
87
+ """Builds an LLM with a resilient fallback cascade.
 
 
 
 
 
88
 
89
+ Priority order:
90
  1. Hugging Face (Qwen 72B) - requires HF_TOKEN
91
  2. Groq (Llama 70B) - requires GROQ_API_KEY
92
  3. Gemini (1.5 Flash) - requires GEMINI_API_KEY
93
+ 4. Local Ollama (Qwen 7B) - always available
 
 
 
 
 
94
  """
95
+ llm = None
96
+ fallbacks = []
97
 
98
+ # PRIMARY: Hugging Face (Qwen 72B)
99
  hf_token = os.getenv("HF_TOKEN")
100
  if hf_token:
101
+ print("HF Token found. Testing models for Primary LLM...", flush=True)
102
  for model_id in CANDIDATE_MODELS:
103
+ print(f"Trying HF model: {model_id}...", flush=True)
104
  try:
105
+ candidate_llm = HuggingFaceChat(repo_id=model_id, hf_token=hf_token, temperature=0.1)
106
+ candidate_llm.invoke("Ping") # Test connection
107
+ llm = candidate_llm
108
+ print(f"Primary LLM: Hugging Face ({model_id})", flush=True)
109
  break
110
  except Exception as e:
111
+ print(f"Failed {model_id}: {str(e)[:100]}...", flush=True)
112
  time.sleep(0.5)
113
 
114
+ # FIRST FALLBACK: Groq (Llama-3.3-70B)
115
  groq_key = os.getenv("GROQ_API_KEY")
116
  if groq_key:
117
  groq_llm = ChatGroq(
 
119
  temperature=0.1,
120
  api_key=groq_key,
121
  )
122
+ if llm is None:
123
+ llm = groq_llm
124
+ print("Primary LLM: Groq (Llama 70B)", flush=True)
125
+ else:
126
+ fallbacks.append(groq_llm)
127
+ print("Added Fallback 1: Groq", flush=True)
128
+
129
+ # SECOND FALLBACK: Gemini (1.5 Flash)
130
  gemini_key = os.getenv("GEMINI_API_KEY")
131
  if gemini_key:
132
  gemini_llm = ChatGoogleGenerativeAI(
 
134
  temperature=0.1,
135
  google_api_key=gemini_key,
136
  )
137
+ if llm is None:
138
+ llm = gemini_llm
139
+ print("Primary LLM: Gemini (1.5 Flash)", flush=True)
140
+ else:
141
+ fallbacks.append(gemini_llm)
142
+ print("Added Fallback 2: Gemini", flush=True)
143
+
144
+ # FINAL FALLBACK: Local Ollama (Qwen 7B)
145
  local_llm = ChatOllama(model="qwen2.5:7b", temperature=0)
146
+ if llm is None:
147
+ llm = local_llm
148
+ print("Primary LLM: Local Ollama (Qwen 7B)", flush=True)
149
+ else:
150
+ fallbacks.append(local_llm)
151
+ print("Added Final Fallback: Local Ollama", flush=True)
152
+
153
+ # Bind fallbacks so LangChain auto-routes on failure
154
+ if fallbacks and hasattr(llm, "with_fallbacks"):
155
+ try:
156
+ # Langchain handles the coercion between LLM and ChatModel types natively
157
+ # when using string prompts.
158
+ robust_llm = llm.with_fallbacks(fallbacks)
159
+ return robust_llm, llm
160
+ except Exception as e:
161
+ print(f"Warning: Fallback binding failed: {e}. Returning base model.", flush=True)
162
+ return llm, llm
163
 
164
+ return llm, llm
 
165
 
166
+ # --- 1. REPLACEMENT CLASS FOR 'Tool' ---
167
  class SimpleTool:
168
  """A simple wrapper to replace langchain.tools.Tool"""
169
  def __init__(self, name, func, description):
 
191
  except Exception as e:
192
  return f"Error executing code: {e}"
193
 
194
+ # --- 2. CLASS FOR THE AGENT ---
195
  class SimpleReActAgent:
196
  """A manual ReAct loop that doesn't rely on langchain.agents"""
197
  def __init__(self, llm, tools, verbose=True):