Spaces:
Sleeping
Sleeping
Revert "feat: Implement a parallel voting ensemble for LLM selection based on Jaccard similarity, replacing the sequential fallback mechanism."
Browse filesThis reverts commit f81b0fced262572d64763a9cb34c07bf38d8cc82.
- src/load/mshauri_demo.py +50 -114
src/load/mshauri_demo.py
CHANGED
|
@@ -4,7 +4,6 @@ import re
|
|
| 4 |
import sys
|
| 5 |
import io
|
| 6 |
import time
|
| 7 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 8 |
from contextlib import redirect_stdout
|
| 9 |
from typing import Any, List, Optional, Mapping
|
| 10 |
|
|
@@ -84,109 +83,35 @@ CANDIDATE_MODELS = [
|
|
| 84 |
"HuggingFaceH4/zephyr-7b-beta", # Old Reliable
|
| 85 |
]
|
| 86 |
|
| 87 |
-
|
| 88 |
-
# --- VOTING ENSEMBLE ---
|
| 89 |
-
class VotingLLM:
|
| 90 |
-
"""
|
| 91 |
-
Calls all available LLMs in parallel and selects the response with
|
| 92 |
-
the highest peer-consensus score (Jaccard word similarity).
|
| 93 |
-
|
| 94 |
-
Models that fail or time out are silently skipped, providing
|
| 95 |
-
built-in fallback behavior without a separate fallback chain.
|
| 96 |
-
If ALL models fail, raises ValueError so the agent can handle it.
|
| 97 |
-
"""
|
| 98 |
-
def __init__(self, llms: list, timeout: int = 45):
|
| 99 |
-
self.llms = llms
|
| 100 |
-
self.timeout = timeout
|
| 101 |
-
|
| 102 |
-
def invoke(self, prompt, stop=None):
|
| 103 |
-
def call_one(llm):
|
| 104 |
-
result = llm.invoke(prompt, stop=stop) if stop else llm.invoke(prompt)
|
| 105 |
-
text = result if isinstance(result, str) else result.content
|
| 106 |
-
return text.strip() if text else None
|
| 107 |
-
|
| 108 |
-
responses = []
|
| 109 |
-
with ThreadPoolExecutor(max_workers=len(self.llms)) as executor:
|
| 110 |
-
futures = {executor.submit(call_one, llm): llm for llm in self.llms}
|
| 111 |
-
try:
|
| 112 |
-
for future in as_completed(futures, timeout=self.timeout):
|
| 113 |
-
llm_name = futures[future].__class__.__name__
|
| 114 |
-
try:
|
| 115 |
-
result = future.result()
|
| 116 |
-
if result:
|
| 117 |
-
responses.append(result)
|
| 118 |
-
print(f"Vote received: {llm_name}", flush=True)
|
| 119 |
-
except Exception as e:
|
| 120 |
-
print(f"Voter {llm_name} failed: {str(e)[:80]}", flush=True)
|
| 121 |
-
except TimeoutError:
|
| 122 |
-
print("Voting timed out. Using responses collected so far.", flush=True)
|
| 123 |
-
|
| 124 |
-
if not responses:
|
| 125 |
-
raise ValueError("All LLMs failed to respond during voting.")
|
| 126 |
-
|
| 127 |
-
if len(responses) == 1:
|
| 128 |
-
return responses[0]
|
| 129 |
-
|
| 130 |
-
return self._pick_consensus(responses)
|
| 131 |
-
|
| 132 |
-
def _pick_consensus(self, responses: list) -> str:
|
| 133 |
-
"""Returns the response with the highest average Jaccard similarity to all others.
|
| 134 |
-
This is the 'centroid' of the group — the most broadly agreed-upon answer.
|
| 135 |
-
"""
|
| 136 |
-
best_score, best_response = -1.0, responses[0]
|
| 137 |
-
for i, r1 in enumerate(responses):
|
| 138 |
-
words1 = set(r1.lower().split())
|
| 139 |
-
scores = []
|
| 140 |
-
for j, r2 in enumerate(responses):
|
| 141 |
-
if i == j:
|
| 142 |
-
continue
|
| 143 |
-
words2 = set(r2.lower().split())
|
| 144 |
-
union = len(words1 | words2)
|
| 145 |
-
scores.append(len(words1 & words2) / union if union else 0.0)
|
| 146 |
-
avg = sum(scores) / len(scores) if scores else 0.0
|
| 147 |
-
if avg > best_score:
|
| 148 |
-
best_score, best_response = avg, r1
|
| 149 |
-
print(f"Consensus winner: score={best_score:.2f}, voters={len(responses)}", flush=True)
|
| 150 |
-
return best_response
|
| 151 |
-
|
| 152 |
def get_robust_llm():
|
| 153 |
-
"""Builds
|
| 154 |
-
|
| 155 |
-
All available models vote simultaneously on every query. The response
|
| 156 |
-
with the highest peer-consensus score (Jaccard word similarity) wins.
|
| 157 |
-
Models that fail during a vote are silently skipped — providing
|
| 158 |
-
built-in fallback behavior.
|
| 159 |
|
| 160 |
-
Priority
|
| 161 |
1. Hugging Face (Qwen 72B) - requires HF_TOKEN
|
| 162 |
2. Groq (Llama 70B) - requires GROQ_API_KEY
|
| 163 |
3. Gemini (1.5 Flash) - requires GEMINI_API_KEY
|
| 164 |
-
4. Local Ollama (Qwen 7B) - always
|
| 165 |
-
|
| 166 |
-
Returns:
|
| 167 |
-
(robust_llm, base_llm)
|
| 168 |
-
robust_llm: VotingLLM for the agent brain (or single model if only 1 available)
|
| 169 |
-
base_llm: Plain highest-priority model for SQLDatabaseToolkit
|
| 170 |
"""
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
-
#
|
| 174 |
hf_token = os.getenv("HF_TOKEN")
|
| 175 |
if hf_token:
|
| 176 |
-
print("HF Token found. Testing
|
| 177 |
for model_id in CANDIDATE_MODELS:
|
| 178 |
-
print(f"
|
| 179 |
try:
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
print(f"
|
| 184 |
break
|
| 185 |
except Exception as e:
|
| 186 |
-
print(f"
|
| 187 |
time.sleep(0.5)
|
| 188 |
|
| 189 |
-
#
|
| 190 |
groq_key = os.getenv("GROQ_API_KEY")
|
| 191 |
if groq_key:
|
| 192 |
groq_llm = ChatGroq(
|
|
@@ -194,10 +119,14 @@ def get_robust_llm():
|
|
| 194 |
temperature=0.1,
|
| 195 |
api_key=groq_key,
|
| 196 |
)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
gemini_key = os.getenv("GEMINI_API_KEY")
|
| 202 |
if gemini_key:
|
| 203 |
gemini_llm = ChatGoogleGenerativeAI(
|
|
@@ -205,29 +134,36 @@ def get_robust_llm():
|
|
| 205 |
temperature=0.1,
|
| 206 |
google_api_key=gemini_key,
|
| 207 |
)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
local_llm = ChatOllama(model="qwen2.5:7b", temperature=0)
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
#
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
-
|
| 228 |
-
return VotingLLM(available_llms), base_llm
|
| 229 |
|
| 230 |
-
# --- CLASS FOR 'Tool' ---
|
| 231 |
class SimpleTool:
|
| 232 |
"""A simple wrapper to replace langchain.tools.Tool"""
|
| 233 |
def __init__(self, name, func, description):
|
|
@@ -255,7 +191,7 @@ class PythonREPLTool(SimpleTool):
|
|
| 255 |
except Exception as e:
|
| 256 |
return f"Error executing code: {e}"
|
| 257 |
|
| 258 |
-
# --- CLASS FOR THE AGENT ---
|
| 259 |
class SimpleReActAgent:
|
| 260 |
"""A manual ReAct loop that doesn't rely on langchain.agents"""
|
| 261 |
def __init__(self, llm, tools, verbose=True):
|
|
|
|
| 4 |
import sys
|
| 5 |
import io
|
| 6 |
import time
|
|
|
|
| 7 |
from contextlib import redirect_stdout
|
| 8 |
from typing import Any, List, Optional, Mapping
|
| 9 |
|
|
|
|
| 83 |
"HuggingFaceH4/zephyr-7b-beta", # Old Reliable
|
| 84 |
]
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def get_robust_llm():
|
| 87 |
+
"""Builds an LLM with a resilient fallback cascade.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
Priority order:
|
| 90 |
1. Hugging Face (Qwen 72B) - requires HF_TOKEN
|
| 91 |
2. Groq (Llama 70B) - requires GROQ_API_KEY
|
| 92 |
3. Gemini (1.5 Flash) - requires GEMINI_API_KEY
|
| 93 |
+
4. Local Ollama (Qwen 7B) - always available
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"""
|
| 95 |
+
llm = None
|
| 96 |
+
fallbacks = []
|
| 97 |
|
| 98 |
+
# PRIMARY: Hugging Face (Qwen 72B)
|
| 99 |
hf_token = os.getenv("HF_TOKEN")
|
| 100 |
if hf_token:
|
| 101 |
+
print("HF Token found. Testing models for Primary LLM...", flush=True)
|
| 102 |
for model_id in CANDIDATE_MODELS:
|
| 103 |
+
print(f"Trying HF model: {model_id}...", flush=True)
|
| 104 |
try:
|
| 105 |
+
candidate_llm = HuggingFaceChat(repo_id=model_id, hf_token=hf_token, temperature=0.1)
|
| 106 |
+
candidate_llm.invoke("Ping") # Test connection
|
| 107 |
+
llm = candidate_llm
|
| 108 |
+
print(f"Primary LLM: Hugging Face ({model_id})", flush=True)
|
| 109 |
break
|
| 110 |
except Exception as e:
|
| 111 |
+
print(f"Failed {model_id}: {str(e)[:100]}...", flush=True)
|
| 112 |
time.sleep(0.5)
|
| 113 |
|
| 114 |
+
# FIRST FALLBACK: Groq (Llama-3.3-70B)
|
| 115 |
groq_key = os.getenv("GROQ_API_KEY")
|
| 116 |
if groq_key:
|
| 117 |
groq_llm = ChatGroq(
|
|
|
|
| 119 |
temperature=0.1,
|
| 120 |
api_key=groq_key,
|
| 121 |
)
|
| 122 |
+
if llm is None:
|
| 123 |
+
llm = groq_llm
|
| 124 |
+
print("Primary LLM: Groq (Llama 70B)", flush=True)
|
| 125 |
+
else:
|
| 126 |
+
fallbacks.append(groq_llm)
|
| 127 |
+
print("Added Fallback 1: Groq", flush=True)
|
| 128 |
+
|
| 129 |
+
# SECOND FALLBACK: Gemini (1.5 Flash)
|
| 130 |
gemini_key = os.getenv("GEMINI_API_KEY")
|
| 131 |
if gemini_key:
|
| 132 |
gemini_llm = ChatGoogleGenerativeAI(
|
|
|
|
| 134 |
temperature=0.1,
|
| 135 |
google_api_key=gemini_key,
|
| 136 |
)
|
| 137 |
+
if llm is None:
|
| 138 |
+
llm = gemini_llm
|
| 139 |
+
print("Primary LLM: Gemini (1.5 Flash)", flush=True)
|
| 140 |
+
else:
|
| 141 |
+
fallbacks.append(gemini_llm)
|
| 142 |
+
print("Added Fallback 2: Gemini", flush=True)
|
| 143 |
+
|
| 144 |
+
# FINAL FALLBACK: Local Ollama (Qwen 7B)
|
| 145 |
local_llm = ChatOllama(model="qwen2.5:7b", temperature=0)
|
| 146 |
+
if llm is None:
|
| 147 |
+
llm = local_llm
|
| 148 |
+
print("Primary LLM: Local Ollama (Qwen 7B)", flush=True)
|
| 149 |
+
else:
|
| 150 |
+
fallbacks.append(local_llm)
|
| 151 |
+
print("Added Final Fallback: Local Ollama", flush=True)
|
| 152 |
+
|
| 153 |
+
# Bind fallbacks so LangChain auto-routes on failure
|
| 154 |
+
if fallbacks and hasattr(llm, "with_fallbacks"):
|
| 155 |
+
try:
|
| 156 |
+
# Langchain handles the coercion between LLM and ChatModel types natively
|
| 157 |
+
# when using string prompts.
|
| 158 |
+
robust_llm = llm.with_fallbacks(fallbacks)
|
| 159 |
+
return robust_llm, llm
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f"Warning: Fallback binding failed: {e}. Returning base model.", flush=True)
|
| 162 |
+
return llm, llm
|
| 163 |
|
| 164 |
+
return llm, llm
|
|
|
|
| 165 |
|
| 166 |
+
# --- 1. REPLACEMENT CLASS FOR 'Tool' ---
|
| 167 |
class SimpleTool:
|
| 168 |
"""A simple wrapper to replace langchain.tools.Tool"""
|
| 169 |
def __init__(self, name, func, description):
|
|
|
|
| 191 |
except Exception as e:
|
| 192 |
return f"Error executing code: {e}"
|
| 193 |
|
| 194 |
+
# --- 2. CLASS FOR THE AGENT ---
|
| 195 |
class SimpleReActAgent:
|
| 196 |
"""A manual ReAct loop that doesn't rely on langchain.agents"""
|
| 197 |
def __init__(self, llm, tools, verbose=True):
|