Spaces:
Sleeping
Sleeping
feat: Enhance LLM resilience with a multi-provider fallback system, upgrade the default Ollama model to 7B, and improve vector store relevance score calculation.
Browse files- Dockerfile +2 -2
- requirements.txt +5 -1
- src/load/mshauri_demo.py +116 -62
Dockerfile
CHANGED
|
@@ -42,8 +42,8 @@ CMD git clone https://huggingface.co/datasets/teofizzy/mshauri-data data_downloa
|
|
| 42 |
echo "Starting Ollama..." && \
|
| 43 |
ollama serve & \
|
| 44 |
sleep 10 && \
|
| 45 |
-
echo "Pulling Fallback Model (
|
| 46 |
-
ollama pull qwen2.5:
|
| 47 |
ollama pull nomic-embed-text && \
|
| 48 |
echo "Models Ready. Launching App..." && \
|
| 49 |
streamlit run src/app.py --server.port 7860 --server.address 0.0.0.0
|
|
|
|
| 42 |
echo "Starting Ollama..." && \
|
| 43 |
ollama serve & \
|
| 44 |
sleep 10 && \
|
| 45 |
+
echo "Pulling Fallback Model (7B)..." && \
|
| 46 |
+
ollama pull qwen2.5:7b && \
|
| 47 |
ollama pull nomic-embed-text && \
|
| 48 |
echo "Models Ready. Launching App..." && \
|
| 49 |
streamlit run src/app.py --server.port 7860 --server.address 0.0.0.0
|
requirements.txt
CHANGED
|
@@ -15,4 +15,8 @@ pypdf
|
|
| 15 |
openpyxl
|
| 16 |
docx2txt
|
| 17 |
pymupdf4llm
|
| 18 |
-
langchain-text-splitters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
openpyxl
|
| 16 |
docx2txt
|
| 17 |
pymupdf4llm
|
| 18 |
+
langchain-text-splitters
|
| 19 |
+
langchain-google-genai
|
| 20 |
+
langchain-google-vertexai
|
| 21 |
+
langchain-groq
|
| 22 |
+
|
src/load/mshauri_demo.py
CHANGED
|
@@ -7,11 +7,11 @@ import time
|
|
| 7 |
from contextlib import redirect_stdout
|
| 8 |
from typing import Any, List, Optional, Mapping
|
| 9 |
|
| 10 |
-
# ---
|
| 11 |
from deep_translator import GoogleTranslator
|
| 12 |
|
| 13 |
# Replaces HuggingFaceEndpoint with the robust Client
|
| 14 |
-
from huggingface_hub import InferenceClient
|
| 15 |
from langchain_core.language_models.llms import LLM
|
| 16 |
|
| 17 |
from langchain_ollama import ChatOllama
|
|
@@ -19,12 +19,14 @@ from langchain_community.utilities import SQLDatabase
|
|
| 19 |
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
| 20 |
from langchain_community.vectorstores import Chroma
|
| 21 |
from langchain_community.embeddings import OllamaEmbeddings
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# --- CONFIGURATION ---
|
| 24 |
DEFAULT_SQL_DB = "sqlite:///mshauri_fedha_v6.db"
|
| 25 |
DEFAULT_VECTOR_DB = "mshauri_fedha_chroma_db"
|
| 26 |
DEFAULT_EMBED_MODEL = "nomic-embed-text"
|
| 27 |
-
DEFAULT_LLM_MODEL = "qwen2.5:
|
| 28 |
DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434"
|
| 29 |
|
| 30 |
# --- CUSTOM WRAPPER ---
|
|
@@ -77,10 +79,84 @@ CANDIDATE_MODELS = [
|
|
| 77 |
"Qwen/Qwen2.5-7B-Instruct", # Lightweight
|
| 78 |
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B", # Smartest Logic
|
| 79 |
"meta-llama/Meta-Llama-3.1-8B-Instruct", # High Availability
|
| 80 |
-
"mistralai/Mistral-Nemo-Instruct-2407",
|
| 81 |
"HuggingFaceH4/zephyr-7b-beta", # Old Reliable
|
| 82 |
]
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# --- 1. REPLACEMENT CLASS FOR 'Tool' ---
|
| 85 |
class SimpleTool:
|
| 86 |
"""A simple wrapper to replace langchain.tools.Tool"""
|
|
@@ -329,44 +405,14 @@ class MultilingualAgent:
|
|
| 329 |
def create_mshauri_agent(
|
| 330 |
sql_db_path=DEFAULT_SQL_DB,
|
| 331 |
vector_db_path=DEFAULT_VECTOR_DB,
|
| 332 |
-
llm_model=DEFAULT_LLM_MODEL,
|
| 333 |
ollama_url=DEFAULT_OLLAMA_URL):
|
| 334 |
-
print(
|
| 335 |
-
|
| 336 |
-
hf_token = os.getenv("HF_TOKEN")
|
| 337 |
-
llm = None
|
| 338 |
|
| 339 |
-
#
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
print(f"Trying model: {model_id}...", flush=True)
|
| 345 |
-
try:
|
| 346 |
-
# USE CUSTOM WRAPPER
|
| 347 |
-
candidate_llm = HuggingFaceChat(
|
| 348 |
-
repo_id=model_id,
|
| 349 |
-
hf_token=hf_token,
|
| 350 |
-
temperature=0.1
|
| 351 |
-
)
|
| 352 |
-
# TEST CALL
|
| 353 |
-
candidate_llm.invoke("Ping")
|
| 354 |
-
|
| 355 |
-
print(f"SUCCESS: Connected to {model_id}", flush=True)
|
| 356 |
-
llm = candidate_llm
|
| 357 |
-
break
|
| 358 |
-
except Exception as e:
|
| 359 |
-
print(f"Failed: {str(e)[:100]}...", flush=True)
|
| 360 |
-
time.sleep(1)
|
| 361 |
-
|
| 362 |
-
# FALLBACK
|
| 363 |
-
if not llm:
|
| 364 |
-
print("\nFalling back to Local CPU Ollama...", flush=True)
|
| 365 |
-
try:
|
| 366 |
-
llm = ChatOllama(model="qwen2.5:3b", base_url=ollama_url, temperature=0.1)
|
| 367 |
-
except Exception as e:
|
| 368 |
-
print(f"Error connecting to Ollama: {e}")
|
| 369 |
-
return None
|
| 370 |
|
| 371 |
# TOOLS
|
| 372 |
if "sqlite" in sql_db_path:
|
|
@@ -384,28 +430,36 @@ def create_mshauri_agent(
|
|
| 384 |
|
| 385 |
def search_docs(query):
|
| 386 |
embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=ollama_url)
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
#
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
formatted_results = []
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
#
|
| 399 |
-
# Formula: exp( - (score^2) / (2 * sigma^2) )
|
| 400 |
-
confidence_float = math.exp(- (score**2) / (2 * sigma**2))
|
| 401 |
-
|
| 402 |
-
# Convert to Percentage
|
| 403 |
-
confidence_pct = int(confidence_float * 100)
|
| 404 |
-
|
| 405 |
-
# Optional: Filter out low confidence noise (< 20%)
|
| 406 |
if confidence_pct < 20:
|
| 407 |
continue
|
| 408 |
-
|
| 409 |
formatted_results.append(f"[Confidence: {confidence_pct}%] {doc.page_content}")
|
| 410 |
|
| 411 |
if not formatted_results:
|
|
@@ -432,16 +486,16 @@ def ask_mshauri(agent, query):
|
|
| 432 |
print(" Agent not initialized.")
|
| 433 |
return
|
| 434 |
|
| 435 |
-
print(f"\
|
| 436 |
print("-" * 40)
|
| 437 |
|
| 438 |
try:
|
| 439 |
response = agent.invoke({"input": query})
|
| 440 |
print("-" * 40)
|
| 441 |
-
print(f"
|
| 442 |
return response['output']
|
| 443 |
except Exception as e:
|
| 444 |
-
print(f"
|
| 445 |
return None
|
| 446 |
|
| 447 |
if __name__ == "__main__":
|
|
|
|
| 7 |
from contextlib import redirect_stdout
|
| 8 |
from typing import Any, List, Optional, Mapping
|
| 9 |
|
| 10 |
+
# --- IMPORT FOR TRANSLATION ---
|
| 11 |
from deep_translator import GoogleTranslator
|
| 12 |
|
| 13 |
# Replaces HuggingFaceEndpoint with the robust Client
|
| 14 |
+
from huggingface_hub import InferenceClient
|
| 15 |
from langchain_core.language_models.llms import LLM
|
| 16 |
|
| 17 |
from langchain_ollama import ChatOllama
|
|
|
|
| 19 |
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
| 20 |
from langchain_community.vectorstores import Chroma
|
| 21 |
from langchain_community.embeddings import OllamaEmbeddings
|
| 22 |
+
from langchain_groq import ChatGroq
|
| 23 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 24 |
|
| 25 |
# --- CONFIGURATION ---
|
| 26 |
DEFAULT_SQL_DB = "sqlite:///mshauri_fedha_v6.db"
|
| 27 |
DEFAULT_VECTOR_DB = "mshauri_fedha_chroma_db"
|
| 28 |
DEFAULT_EMBED_MODEL = "nomic-embed-text"
|
| 29 |
+
DEFAULT_LLM_MODEL = "qwen2.5:7b"
|
| 30 |
DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434"
|
| 31 |
|
| 32 |
# --- CUSTOM WRAPPER ---
|
|
|
|
| 79 |
"Qwen/Qwen2.5-7B-Instruct", # Lightweight
|
| 80 |
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B", # Smartest Logic
|
| 81 |
"meta-llama/Meta-Llama-3.1-8B-Instruct", # High Availability
|
| 82 |
+
"mistralai/Mistral-Nemo-Instruct-2407", # Large Context
|
| 83 |
"HuggingFaceH4/zephyr-7b-beta", # Old Reliable
|
| 84 |
]
|
| 85 |
|
| 86 |
+
|
| 87 |
+
def get_robust_llm(ollama_model=None, ollama_url=None):
    """Build an LLM with a fallback cascade across several providers.

    Providers are considered in the order below. The first usable one becomes
    the primary LLM; every later one is queued as a fallback:

      1. Hugging Face - requires HF_TOKEN. Each entry of CANDIDATE_MODELS is
         probed with a test call and the first one that answers is used.
      2. Groq (llama-3.3-70b-versatile) - requires GROQ_API_KEY.
      3. Gemini (gemini-1.5-flash) - requires GEMINI_API_KEY.
      4. Local Ollama (DEFAULT_LLM_MODEL unless overridden) - always attempted
         last.

    Only the Hugging Face candidates are verified with a real request. The
    other clients are merely constructed here, so an unreachable service is
    only discovered when the model is first invoked.

    Args:
        ollama_model: Name of the local Ollama model. Defaults to
            DEFAULT_LLM_MODEL when None.
        ollama_url: Base URL of the Ollama server. When None, no base_url is
            passed and the client falls back to its own default.

    Returns:
        The primary LLM, wrapped with ``with_fallbacks`` when at least one
        fallback was queued. Returns None only if no client at all could be
        constructed.
    """
    llm = None
    fallbacks = []

    # PRIMARY: Hugging Face - probe candidates until one responds.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("HF Token found. Testing models for Primary LLM...", flush=True)
        for model_id in CANDIDATE_MODELS:
            print(f"Trying HF model: {model_id}...", flush=True)
            try:
                candidate_llm = HuggingFaceChat(repo_id=model_id, hf_token=hf_token, temperature=0.1)
                candidate_llm.invoke("Ping")  # Test connection
            except Exception as e:
                print(f"Failed {model_id}: {str(e)[:100]}...", flush=True)
                time.sleep(0.5)
            else:
                llm = candidate_llm
                print(f"Primary LLM: Hugging Face ({model_id})", flush=True)
                break

    # FIRST FALLBACK: Groq (Llama-3.3-70B)
    groq_key = os.getenv("GROQ_API_KEY")
    if groq_key:
        try:
            groq_llm = ChatGroq(
                model="llama-3.3-70b-versatile",
                temperature=0.1,
                api_key=groq_key,
            )
        except Exception as e:
            # A broken optional provider must not take the whole cascade down.
            print(f"Skipping Groq: {str(e)[:100]}...", flush=True)
        else:
            llm = _adopt_llm(
                llm, fallbacks, groq_llm,
                "Primary LLM: Groq (Llama 70B)",
                "Added Fallback 1: Groq",
            )

    # SECOND FALLBACK: Gemini (1.5 Flash)
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        try:
            gemini_llm = ChatGoogleGenerativeAI(
                model="gemini-1.5-flash",
                temperature=0.1,
                google_api_key=gemini_key,
            )
        except Exception as e:
            print(f"Skipping Gemini: {str(e)[:100]}...", flush=True)
        else:
            llm = _adopt_llm(
                llm, fallbacks, gemini_llm,
                "Primary LLM: Gemini (1.5 Flash)",
                "Added Fallback 2: Gemini",
            )

    # FINAL FALLBACK: Local Ollama. The model comes from the module
    # configuration instead of a second hard-coded copy of the name.
    ollama_kwargs = {"model": ollama_model or DEFAULT_LLM_MODEL, "temperature": 0}
    if ollama_url is not None:
        # Only override the endpoint when the caller asked for one, so the
        # no-argument call behaves exactly as before.
        ollama_kwargs["base_url"] = ollama_url
    try:
        local_llm = ChatOllama(**ollama_kwargs)
    except Exception as e:
        print(f"Skipping local Ollama: {str(e)[:100]}...", flush=True)
    else:
        llm = _adopt_llm(
            llm, fallbacks, local_llm,
            "Primary LLM: Local Ollama (Qwen 7B)",
            "Added Final Fallback: Local Ollama",
        )

    # Bind fallbacks so LangChain auto-routes on failure. `fallbacks` can only
    # be non-empty when a primary exists, so `llm` is not None here.
    if fallbacks and hasattr(llm, "with_fallbacks"):
        return llm.with_fallbacks(fallbacks)

    return llm


def _adopt_llm(primary, fallbacks, candidate, primary_msg, fallback_msg):
    """Promote `candidate` to primary if there is none yet, else queue it as a fallback.

    Returns the LLM that should be treated as primary afterwards.
    """
    if primary is None:
        print(primary_msg, flush=True)
        return candidate
    fallbacks.append(candidate)
    print(fallback_msg, flush=True)
    return primary
|
| 159 |
+
|
| 160 |
# --- 1. REPLACEMENT CLASS FOR 'Tool' ---
|
| 161 |
class SimpleTool:
|
| 162 |
"""A simple wrapper to replace langchain.tools.Tool"""
|
|
|
|
| 405 |
def create_mshauri_agent(
|
| 406 |
sql_db_path=DEFAULT_SQL_DB,
|
| 407 |
vector_db_path=DEFAULT_VECTOR_DB,
|
|
|
|
| 408 |
ollama_url=DEFAULT_OLLAMA_URL):
|
| 409 |
+
print("Initializing Mshauri Fedha...", flush=True)
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
# Build a resilient LLM with automatic fallback cascade
|
| 412 |
+
llm = get_robust_llm()
|
| 413 |
+
if llm is None:
|
| 414 |
+
print("Error: No LLM could be initialised. Check your API keys and Ollama service.")
|
| 415 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# TOOLS
|
| 418 |
if "sqlite" in sql_db_path:
|
|
|
|
| 430 |
|
| 431 |
def search_docs(query):
|
| 432 |
embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=ollama_url)
|
| 433 |
+
|
| 434 |
+
# The Chroma collection was created with the default "l2" space (no cosine metadata).
# Chroma's "l2" space reports the SQUARED Euclidean distance, d = sum((a_i - b_i)^2).
# For unit-length vectors d = 2 - 2 * cos(a, b), hence cosine_sim = 1 - d / 2.
# The distance must not be squared a second time: 1 - d^2 / 2 overstates the
# similarity for d < 1 and understates it for d > 1.
# NOTE(review): the identity only holds if stored and query embeddings are
# unit-normalised - confirm this for the embedding endpoint in use.
# The function is passed as relevance_score_fn so LangChain uses it in place of
# its built-in Euclidean score normalisation.
def cosine_from_l2(l2_distance: float) -> float:
    """Convert a Chroma "l2" (squared Euclidean) distance to cosine similarity.

    Args:
        l2_distance: Distance reported by Chroma for the default "l2" space,
            i.e. the squared Euclidean distance between two embeddings.

    Returns:
        The cosine similarity clamped to [0, 1]; negative similarities
        (distance above 2) are reported as 0.
    """
    cosine_sim = 1.0 - l2_distance / 2.0
    # Clamp to [0, 1] to guard against floating-point edge cases and
    # negative cosine values.
    return max(0.0, min(1.0, cosine_sim))
|
| 443 |
+
|
| 444 |
+
vectorstore = Chroma(
|
| 445 |
+
persist_directory=vector_db_path,
|
| 446 |
+
embedding_function=embeddings,
|
| 447 |
+
relevance_score_fn=cosine_from_l2,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Returns (doc, cosine_similarity) pairs in [0, 1]
|
| 451 |
+
results = vectorstore.similarity_search_with_relevance_scores(query, k=5)
|
| 452 |
+
|
| 453 |
formatted_results = []
|
| 454 |
+
|
| 455 |
+
for doc, cosine_sim in results:
|
| 456 |
+
# Map cosine similarity [0, 1] directly to a percentage
|
| 457 |
+
confidence_pct = int(cosine_sim * 100)
|
| 458 |
+
|
| 459 |
+
# Filter out low-confidence noise (< 20%)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
if confidence_pct < 20:
|
| 461 |
continue
|
| 462 |
+
|
| 463 |
formatted_results.append(f"[Confidence: {confidence_pct}%] {doc.page_content}")
|
| 464 |
|
| 465 |
if not formatted_results:
|
|
|
|
| 486 |
print(" Agent not initialized.")
|
| 487 |
return
|
| 488 |
|
| 489 |
+
print(f"\nUser: {query}")
|
| 490 |
print("-" * 40)
|
| 491 |
|
| 492 |
try:
|
| 493 |
response = agent.invoke({"input": query})
|
| 494 |
print("-" * 40)
|
| 495 |
+
print(f"Mshauri: {response['output']}")
|
| 496 |
return response['output']
|
| 497 |
except Exception as e:
|
| 498 |
+
print(f"Error during execution: {e}")
|
| 499 |
return None
|
| 500 |
|
| 501 |
if __name__ == "__main__":
|