teofizzy committed on
Commit
5f850dc
·
1 Parent(s): 48535aa

feat: Enhance LLM resilience with a multi-provider fallback system, upgrade the default Ollama model to 7B, and improve vector store relevance score calculation.

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -2
  2. requirements.txt +5 -1
  3. src/load/mshauri_demo.py +116 -62
Dockerfile CHANGED
@@ -42,8 +42,8 @@ CMD git clone https://huggingface.co/datasets/teofizzy/mshauri-data data_downloa
42
  echo "Starting Ollama..." && \
43
  ollama serve & \
44
  sleep 10 && \
45
- echo "Pulling Fallback Model (3B)..." && \
46
- ollama pull qwen2.5:3b && \
47
  ollama pull nomic-embed-text && \
48
  echo "Models Ready. Launching App..." && \
49
  streamlit run src/app.py --server.port 7860 --server.address 0.0.0.0
 
42
  echo "Starting Ollama..." && \
43
  ollama serve & \
44
  sleep 10 && \
45
+ echo "Pulling Fallback Model (7B)..." && \
46
+ ollama pull qwen2.5:7b && \
47
  ollama pull nomic-embed-text && \
48
  echo "Models Ready. Launching App..." && \
49
  streamlit run src/app.py --server.port 7860 --server.address 0.0.0.0
requirements.txt CHANGED
@@ -15,4 +15,8 @@ pypdf
15
  openpyxl
16
  docx2txt
17
  pymupdf4llm
18
- langchain-text-splitters
 
 
 
 
 
15
  openpyxl
16
  docx2txt
17
  pymupdf4llm
18
+ langchain-text-splitters
19
+ langchain-google-genai
20
+ langchain-google-vertexai
21
+ langchain-groq
22
+
src/load/mshauri_demo.py CHANGED
@@ -7,11 +7,11 @@ import time
7
  from contextlib import redirect_stdout
8
  from typing import Any, List, Optional, Mapping
9
 
10
- # --- NEW IMPORT FOR TRANSLATION ---
11
  from deep_translator import GoogleTranslator
12
 
13
  # Replaces HuggingFaceEndpoint with the robust Client
14
- from huggingface_hub import InferenceClient
15
  from langchain_core.language_models.llms import LLM
16
 
17
  from langchain_ollama import ChatOllama
@@ -19,12 +19,14 @@ from langchain_community.utilities import SQLDatabase
19
  from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
20
  from langchain_community.vectorstores import Chroma
21
  from langchain_community.embeddings import OllamaEmbeddings
 
 
22
 
23
  # --- CONFIGURATION ---
24
  DEFAULT_SQL_DB = "sqlite:///mshauri_fedha_v6.db"
25
  DEFAULT_VECTOR_DB = "mshauri_fedha_chroma_db"
26
  DEFAULT_EMBED_MODEL = "nomic-embed-text"
27
- DEFAULT_LLM_MODEL = "qwen2.5:3b"
28
  DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434"
29
 
30
  # --- CUSTOM WRAPPER ---
@@ -77,10 +79,84 @@ CANDIDATE_MODELS = [
77
  "Qwen/Qwen2.5-7B-Instruct", # Lightweight
78
  "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", # Smartest Logic
79
  "meta-llama/Meta-Llama-3.1-8B-Instruct", # High Availability
80
- "mistralai/Mistral-Nemo-Instruct-2407", # Large Context
81
  "HuggingFaceH4/zephyr-7b-beta", # Old Reliable
82
  ]
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  # --- 1. REPLACEMENT CLASS FOR 'Tool' ---
85
  class SimpleTool:
86
  """A simple wrapper to replace langchain.tools.Tool"""
@@ -329,44 +405,14 @@ class MultilingualAgent:
329
  def create_mshauri_agent(
330
  sql_db_path=DEFAULT_SQL_DB,
331
  vector_db_path=DEFAULT_VECTOR_DB,
332
- llm_model=DEFAULT_LLM_MODEL,
333
  ollama_url=DEFAULT_OLLAMA_URL):
334
- print(f"Initializing Mshauri Fedha...")
335
-
336
- hf_token = os.getenv("HF_TOKEN")
337
- llm = None
338
 
339
- # 1. ROBUST SERVERLESS LOADING LOOP
340
- if hf_token:
341
- print("HF Token found. Testing models...", flush=True)
342
-
343
- for model_id in CANDIDATE_MODELS:
344
- print(f"Trying model: {model_id}...", flush=True)
345
- try:
346
- # USE CUSTOM WRAPPER
347
- candidate_llm = HuggingFaceChat(
348
- repo_id=model_id,
349
- hf_token=hf_token,
350
- temperature=0.1
351
- )
352
- # TEST CALL
353
- candidate_llm.invoke("Ping")
354
-
355
- print(f"SUCCESS: Connected to {model_id}", flush=True)
356
- llm = candidate_llm
357
- break
358
- except Exception as e:
359
- print(f"Failed: {str(e)[:100]}...", flush=True)
360
- time.sleep(1)
361
-
362
- # FALLBACK
363
- if not llm:
364
- print("\nFalling back to Local CPU Ollama...", flush=True)
365
- try:
366
- llm = ChatOllama(model="qwen2.5:3b", base_url=ollama_url, temperature=0.1)
367
- except Exception as e:
368
- print(f"Error connecting to Ollama: {e}")
369
- return None
370
 
371
  # TOOLS
372
  if "sqlite" in sql_db_path:
@@ -384,28 +430,36 @@ def create_mshauri_agent(
384
 
385
  def search_docs(query):
386
  embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=ollama_url)
387
- vectorstore = Chroma(persist_directory=vector_db_path, embedding_function=embeddings)
388
-
389
- # Get L2 Distance (Lower is better)
390
- results = vectorstore.similarity_search_with_score(query, k=5)
391
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  formatted_results = []
393
-
394
- # SIGMA: This controls how strict you are.
395
- sigma = 350 # pivot point for scaling
396
-
397
- for doc, score in results:
398
- # Gaussian RBF Kernel
399
- # Formula: exp( - (score^2) / (2 * sigma^2) )
400
- confidence_float = math.exp(- (score**2) / (2 * sigma**2))
401
-
402
- # Convert to Percentage
403
- confidence_pct = int(confidence_float * 100)
404
-
405
- # Optional: Filter out low confidence noise (< 20%)
406
  if confidence_pct < 20:
407
  continue
408
-
409
  formatted_results.append(f"[Confidence: {confidence_pct}%] {doc.page_content}")
410
 
411
  if not formatted_results:
@@ -432,16 +486,16 @@ def ask_mshauri(agent, query):
432
  print(" Agent not initialized.")
433
  return
434
 
435
- print(f"\n User: {query}")
436
  print("-" * 40)
437
 
438
  try:
439
  response = agent.invoke({"input": query})
440
  print("-" * 40)
441
- print(f" Mshauri: {response['output']}")
442
  return response['output']
443
  except Exception as e:
444
- print(f" Error during execution: {e}")
445
  return None
446
 
447
  if __name__ == "__main__":
 
7
  from contextlib import redirect_stdout
8
  from typing import Any, List, Optional, Mapping
9
 
10
+ # --- IMPORT FOR TRANSLATION ---
11
  from deep_translator import GoogleTranslator
12
 
13
  # Replaces HuggingFaceEndpoint with the robust Client
14
+ from huggingface_hub import InferenceClient
15
  from langchain_core.language_models.llms import LLM
16
 
17
  from langchain_ollama import ChatOllama
 
19
  from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
20
  from langchain_community.vectorstores import Chroma
21
  from langchain_community.embeddings import OllamaEmbeddings
22
+ from langchain_groq import ChatGroq
23
+ from langchain_google_genai import ChatGoogleGenerativeAI
24
 
25
  # --- CONFIGURATION ---
26
  DEFAULT_SQL_DB = "sqlite:///mshauri_fedha_v6.db"
27
  DEFAULT_VECTOR_DB = "mshauri_fedha_chroma_db"
28
  DEFAULT_EMBED_MODEL = "nomic-embed-text"
29
+ DEFAULT_LLM_MODEL = "qwen2.5:7b"
30
  DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434"
31
 
32
  # --- CUSTOM WRAPPER ---
 
79
  "Qwen/Qwen2.5-7B-Instruct", # Lightweight
80
  "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", # Smartest Logic
81
  "meta-llama/Meta-Llama-3.1-8B-Instruct", # High Availability
82
+ "mistralai/Mistral-Nemo-Instruct-2407", # Large Context
83
  "HuggingFaceH4/zephyr-7b-beta", # Old Reliable
84
  ]
85
 
86
+
87
def get_robust_llm(ollama_url: str = DEFAULT_OLLAMA_URL):
    """Build an LLM with a resilient multi-provider fallback cascade.

    Priority order:
      1. Hugging Face serverless - first reachable CANDIDATE_MODELS entry
         (7B-8B instruct models); requires HF_TOKEN.
      2. Groq (llama-3.3-70b-versatile) - requires GROQ_API_KEY.
      3. Gemini (gemini-1.5-flash) - requires GEMINI_API_KEY.
      4. Local Ollama (qwen2.5:7b) - final fallback, no API key needed.

    The first available provider becomes the primary model; every other
    available provider is bound as a LangChain fallback via
    ``with_fallbacks`` so requests auto-route on failure.

    Args:
        ollama_url: Base URL of the local Ollama server used for the final
            fallback. Defaults to DEFAULT_OLLAMA_URL, so existing callers
            that pass no arguments are unaffected.

    Returns:
        A LangChain chat model, wrapped with fallbacks when any were
        registered; the bare primary model otherwise.
    """
    llm = None
    fallbacks = []

    # PRIMARY: Hugging Face serverless - try each candidate until one
    # answers a test ping.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("HF Token found. Testing models for Primary LLM...", flush=True)
        for model_id in CANDIDATE_MODELS:
            print(f"Trying HF model: {model_id}...", flush=True)
            try:
                candidate_llm = HuggingFaceChat(repo_id=model_id, hf_token=hf_token, temperature=0.1)
                candidate_llm.invoke("Ping")  # Test connection
                llm = candidate_llm
                print(f"Primary LLM: Hugging Face ({model_id})", flush=True)
                break
            except Exception as e:
                # Truncate provider error text so the log stays readable.
                print(f"Failed {model_id}: {str(e)[:100]}...", flush=True)
                time.sleep(0.5)

    # FIRST FALLBACK: Groq (Llama-3.3-70B)
    groq_key = os.getenv("GROQ_API_KEY")
    if groq_key:
        groq_llm = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0.1,
            api_key=groq_key,
        )
        if llm is None:
            llm = groq_llm
            print("Primary LLM: Groq (Llama 70B)", flush=True)
        else:
            fallbacks.append(groq_llm)
            print("Added Fallback 1: Groq", flush=True)

    # SECOND FALLBACK: Gemini (1.5 Flash)
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        gemini_llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.1,
            google_api_key=gemini_key,
        )
        if llm is None:
            llm = gemini_llm
            print("Primary LLM: Gemini (1.5 Flash)", flush=True)
        else:
            fallbacks.append(gemini_llm)
            print("Added Fallback 2: Gemini", flush=True)

    # FINAL FALLBACK: Local Ollama (Qwen 7B). base_url is passed explicitly
    # so this targets the same Ollama instance as the rest of the app.
    local_llm = ChatOllama(model="qwen2.5:7b", base_url=ollama_url, temperature=0)
    if llm is None:
        llm = local_llm
        print("Primary LLM: Local Ollama (Qwen 7B)", flush=True)
    else:
        fallbacks.append(local_llm)
        print("Added Final Fallback: Local Ollama", flush=True)

    # Bind fallbacks so LangChain auto-routes on failure
    if fallbacks and hasattr(llm, "with_fallbacks"):
        return llm.with_fallbacks(fallbacks)

    return llm
159
+
160
  # --- 1. REPLACEMENT CLASS FOR 'Tool' ---
161
  class SimpleTool:
162
  """A simple wrapper to replace langchain.tools.Tool"""
 
405
  def create_mshauri_agent(
406
  sql_db_path=DEFAULT_SQL_DB,
407
  vector_db_path=DEFAULT_VECTOR_DB,
 
408
  ollama_url=DEFAULT_OLLAMA_URL):
409
+ print("Initializing Mshauri Fedha...", flush=True)
 
 
 
410
 
411
+ # Build a resilient LLM with automatic fallback cascade
412
+ llm = get_robust_llm()
413
+ if llm is None:
414
+ print("Error: No LLM could be initialised. Check your API keys and Ollama service.")
415
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
  # TOOLS
418
  if "sqlite" in sql_db_path:
 
430
 
431
  def search_docs(query):
432
  embeddings = OllamaEmbeddings(model=DEFAULT_EMBED_MODEL, base_url=ollama_url)
433
+
434
+ # The Chroma collection was created with default L2 distance (no cosine metadata).
435
+ # nomic-embed-text produces unit-normalised vectors, so we can recover exact cosine
436
+ # similarity from L2 distance using: cosine_sim = 1 - (l2_distance^2 / 2)
437
+ # We pass this as a custom relevance_score_fn so LangChain uses it instead of
438
+ # its default L2 normalisation formula (1 - l2/sqrt(2)), giving true cosine percentages.
439
+ def cosine_from_l2(l2_distance: float) -> float:
440
+ cosine_sim = 1.0 - (l2_distance ** 2) / 2.0
441
+ # Clamp to [0, 1] to guard against floating-point edge cases
442
+ return max(0.0, min(1.0, cosine_sim))
443
+
444
+ vectorstore = Chroma(
445
+ persist_directory=vector_db_path,
446
+ embedding_function=embeddings,
447
+ relevance_score_fn=cosine_from_l2,
448
+ )
449
+
450
+ # Returns (doc, cosine_similarity) pairs in [0, 1]
451
+ results = vectorstore.similarity_search_with_relevance_scores(query, k=5)
452
+
453
  formatted_results = []
454
+
455
+ for doc, cosine_sim in results:
456
+ # Map cosine similarity [0, 1] directly to a percentage
457
+ confidence_pct = int(cosine_sim * 100)
458
+
459
+ # Filter out low-confidence noise (< 20%)
 
 
 
 
 
 
 
460
  if confidence_pct < 20:
461
  continue
462
+
463
  formatted_results.append(f"[Confidence: {confidence_pct}%] {doc.page_content}")
464
 
465
  if not formatted_results:
 
486
  print(" Agent not initialized.")
487
  return
488
 
489
+ print(f"\nUser: {query}")
490
  print("-" * 40)
491
 
492
  try:
493
  response = agent.invoke({"input": query})
494
  print("-" * 40)
495
+ print(f"Mshauri: {response['output']}")
496
  return response['output']
497
  except Exception as e:
498
+ print(f"Error during execution: {e}")
499
  return None
500
 
501
  if __name__ == "__main__":