import os import gc import logging import unicodedata import re import collections from llama_index.core.schema import NodeWithScore # Explicitly import NodeWithScore for clarity from typing import List, Dict, Any, Optional, Callable import time import random import torch import unicodedata import re from functools import lru_cache from typing import Any, List, Optional import torch from weaviate.classes.init import Auth from weaviate.agents.query import QueryAgent from weaviate_agents.classes import QueryAgentCollectionConfig import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) # Chainlit import chainlit as cl from chainlit.input_widget import Select from chainlit.types import ThreadDict from chainlit.input_widget import Select import asyncio # Lazy imports for heavy libraries -- imported inside startup # from llama_index.core.settings import Settings # from llama_index.core import Document, StorageContext, VectorStoreIndex # from llama_index.vector_stores.weaviate import WeaviateVectorStore logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) # Best practice: store your credentials in environment variables WEAVIATE_URL = "puifkdlvt8kgh7ga2rtuca.c0.us-west3.gcp.weaviate.cloud" WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY") GEMINI_API_KEY=os.getenv("api_key") # For CPU-only environments os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("CUDA_VISIBLE_DEVICES", "-1") import collections from typing import List, Any, Optional import uuid import uuid # Added missing import from llama_index.core.schema import TextNode, NodeWithScore # Added missing imports # Use LlamaIndex standard schemas if available, otherwise simulate them try: from llama_index.core.schema import NodeWithScore, TextNode except ImportError: # Fallback if you don't have llama_index installed but want the structure class TextNode: def __init__(self, text: str, id_: str = None): self.text = text self.node_id = id_ or str(uuid.uuid4()) class NodeWithScore: def __init__(self, node: TextNode, score: float = None): self.node = node self.score = score @property def text(self): return self.node.text class WeaviateAgentRetriever: def __init__(self, agent): self.agent = agent def retrieve(self, queries: List[str], top_k=5, llm=None) -> List[NodeWithScore]: results = [] # FIX 1: Ensure we handle single strings gracefully, though we expect a list if isinstance(queries, str): queries = [queries] for q in queries: # Call the agent response = self.agent.search(q, limit=top_k) # Parse results # Note: Adjust 'response.search_results.objects' based on actual Weaviate response structure if hasattr(response, 'search_results') and hasattr(response.search_results, 'objects'): iterator = response.search_results.objects else: iterator = [] # Handle empty/error cases safely for obj in iterator: text = obj.properties.get("text") or obj.properties.get("content") if text: # FIX 2: Create a TextNode first node = TextNode(text=text) # FIX 3: Wrap it in NodeWithScore (default score to 1.0 or fetch from Weaviate metadata) # Weaviate usually returns certainty or distance, usually mapped to score # Corrected: Access certainty from obj._additional score = obj.metadata.get("score", 1.0) results.append(NodeWithScore(node=node, score=score)) return results # ----------------------------- # Utilities # ----------------------------- def normalize_yoruba(text: str) -> str: if not text: return "" text = unicodedata.normalize("NFC", text) text = re.sub(r"\s+", " ", text).strip() return text def free_memory(): gc.collect() try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass # Paste your HuggingFace access token here HUGGINGFACE_TOKEN = os.getenv("hf_api") try: login(HUGGINGFACE_TOKEN) except Exception as e: print(f"Warning: HuggingFace login failed: {e}") print("Continuing without authentication...") from llama_index.core.embeddings import BaseEmbedding from typing import Any, List import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel class AfriBERTaEmbedding(BaseEmbedding): _model: Any = None _tokenizer: Any = None _device: Any = None def __init__( self, model_name: str = "Davlan/afro-xlmr-mini", **kwargs: Any ) -> None: super().__init__(model_name=model_name, **kwargs) # CPU device self._device = torch.device("cpu") # 2. Load tokenizer (use_fast=False to avoid tokenizer.json issues) self._tokenizer = AutoTokenizer.from_pretrained( model_name, use_fast=False ) # 3. Load model on CPU self._model = AutoModel.from_pretrained(model_name).to(self._device) self._model.eval() def _mean_pooling(self, token_embeddings, attention_mask): input_mask_expanded = ( attention_mask.unsqueeze(-1) .expand(token_embeddings.size()) .float() ) return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( input_mask_expanded.sum(1), min=1e-9 ) def _embed(self, texts: List[str]) -> List[List[float]]: """Core embedding logic""" inputs = self._tokenizer( texts, return_tensors="pt", padding=True, truncation=True, max_length=512 # important for RAG ).to(self._device) with torch.no_grad(): outputs = self._model(**inputs) # Pool embeddings = self._mean_pooling( outputs.last_hidden_state, inputs["attention_mask"] ) # Normalize embeddings = F.normalize(embeddings, p=2, dim=1) return embeddings.tolist() # --- LlamaIndex required methods --- def _get_query_embedding(self, query: str) -> List[float]: return self._embed([query])[0] def _get_text_embedding(self, text: str) -> List[float]: return self._embed([text])[0] def _get_text_embedding_batch(self, texts: List[str]) -> List[List[float]]: return self._embed(texts) async def _aget_query_embedding(self, query: str) -> List[float]: return self._get_query_embedding(query) async def _aget_text_embedding(self, text: str) -> List[float]: return self._get_text_embedding(text) # ----------------------------- # Safe LLM completion wrapper # ----------------------------- logger = logging.getLogger(__name__) def safe_llm_complete(llm, prompt: str) -> Optional[str]: if llm is None: logger.warning("LLM is None — cannot complete prompt.") return None try: if hasattr(llm, "complete"): resp = llm.complete(prompt) if resp is None: return None if hasattr(resp, "text") and resp.text: return str(resp.text).strip() if hasattr(resp, "output_text") and resp.output_text: return str(resp.output_text).strip() return str(resp).strip() if hasattr(llm, "chat"): resp = llm.chat(messages=[{"role": "user", "content": prompt}]) # Try common response shapes if hasattr(resp, "message") and hasattr(resp.message, "content"): return str(resp.message.content).strip() if hasattr(resp, "output_text"): return str(resp.output_text).strip() return str(resp).strip() # Generic fallback return str(llm(prompt)).strip() except Exception as e: logger.warning(f"Safe LLM call failed: {e}") return None # ----------------------------- # Gemini loader (cached) # ----------------------------- @lru_cache(maxsize=4) def load_gemini_llm(api_key: str, model: str = "models/gemini-2.5-flash"): try: from llama_index.llms.gemini import Gemini llm = Gemini(model=model, api_key=api_key) logger.info("Gemini LLM loaded") return llm except Exception as e: logger.warning(f"Could not load Gemini LLM: {e}") return None class PreRetrievalModule: """Expands query using HyDE (Hypothetical Document Embedding).""" def __init__(self, llm: Optional[Any] = None, enable_hyde: bool = True, enable_expansion: bool = True): self.llm = llm self.enable_hyde = enable_hyde self.enable_expansion = enable_expansion def process(self, query: str) -> List[str]: """Apply HyDE and query expansion to produce query variants.""" queries = [query] if self.enable_hyde and self.llm: hyde_prompt = f"Write a short factual paragraph that could answer: '{query}'" hypo_doc = safe_llm_complete(self.llm, hyde_prompt) if hypo_doc: queries.append(hypo_doc) if self.enable_expansion and self.llm: expansion_prompt = f"Expand this question with related terms: '{query}'" expansion = safe_llm_complete(self.llm, expansion_prompt) if expansion: queries.append(expansion) return queries class RetrievalModule: """Combines dense and sparse retrieval (hybrid).""" def __init__(self, dense_retriever: Any, sparse_retriever: Optional[Any] = None, hybrid_alpha: float = 0.5): self.dense = dense_retriever self.sparse = sparse_retriever self.hybrid_alpha = hybrid_alpha def retrieve(self, queries: List[str], top_k: int = 10) -> List[NodeWithScore]: # Adjusted return type hint all_results_raw = [] # Store raw NodeWithScore objects for q in queries: # Dense retrieval dense_results = self.dense.retrieve(q) all_results_raw.extend(dense_results) if self.sparse: # Sparse retrieval sparse_results = self.sparse.retrieve(q) all_results_raw.extend(sparse_results) # Use a dictionary to deduplicate based on node_id, preserving NodeWithScore unique_nodes_map = collections.OrderedDict() for node_with_score in all_results_raw: # Assuming node_with_score.node.node_id is unique enough unique_nodes_map[node_with_score.node.node_id] = node_with_score # Return list of NodeWithScore objects, limited by top_k return list(unique_nodes_map.values())[:top_k] class PostRetrievalModule: """Applies reranking, filtering, and compression.""" def __init__(self, embed_model): self.embed_model = embed_model def rerank(self, nodes, query, top_k=3): if not nodes: return [] # Use _embed and convert to tensor query_emb = torch.tensor(self.embed_model._embed(query)[0]) node_texts = [n.text for n in nodes] node_embs = torch.tensor(self.embed_model._embed(node_texts)) scores = torch.nn.functional.cosine_similarity(query_emb.unsqueeze(0), node_embs) ranked = sorted(zip(nodes, scores.tolist()), key=lambda x: x[1], reverse=True) return [n for n, _ in ranked[:top_k]] def context_filter(self, nodes, diversity_threshold=0.8): unique_nodes, seen = [], [] for n in nodes: text = n.text.strip() if not text: continue # Use _embed and convert to tensor emb = torch.tensor(self.embed_model._embed(text[:512])[0]) if all(torch.nn.functional.cosine_similarity(emb, s, dim=0) < diversity_threshold for s in seen): unique_nodes.append(n) seen.append(emb) return unique_nodes def compress_contexts(self, nodes, max_len=1500): texts = [] total_len = 0 for n in nodes: t = n.text.strip() if total_len + len(t) > max_len: break texts.append(t) total_len += len(t) return "\n\n".join(texts) class GenerationModule: def __init__(self, llm, verify: bool = True): self.llm = llm self.verify = verify def generate_with_fallback(self, query: str, domain: str = "General", context: str = "") -> str: fallback_prompt = f""" Ìwọ jẹ́ ọ̀jọ̀gbọ́n nínú {domain}. Fún ìdáhùn kan sí ìbéèrè yìí ní èdè Yorùbá: Ìbéèrè: {query} Jọwọ pèsè ìdáhùn tó dájú. Answer in clear Yoruba with proper paragraph spacing """ resp = safe_llm_complete(self.llm, fallback_prompt) if resp: return resp return "⚠️ Ko ṣee ṣe lati gba ìdáhùn lọwọ LLM." def generate(self, query: str, context: str, domain: str = "General") -> str: prompt = build_yoruba_prompt(query=query, context=context, domain=domain) raw = safe_llm_complete(self.llm, prompt) if raw is None: return "⚠️ Ko ṣee ṣe lati gba ìdáhùn lọwọ LLM." return raw.strip() # ----------------------------- # Dummy retriever fallback # ----------------------------- class DummyRetriever: def retrieve(self, queries, top_k=10): return [] class OrchestrationModule: """Coordinates all plug-in modules.""" def __init__(self, pre, retriever, post, generator,llm_fallback_threshold: float = 0.3): self.pre = pre self.retriever = retriever self.post = post self.generator = generator self.llm_fallback_threshold = llm_fallback_threshold def modular_query(self, question: str, domain: str = "General", top_k: int = 5): query = normalize_yoruba(question) queries=self.pre.process(query) print(f"🔍 Expanded Queries: {queries}") nodes=[] if self.retriever: nodes= self.retriever.retrieve(queries, top_k=top_k) if not nodes: return { "question": question, "answer": self.generator.generate_with_fallback(question, domain, ""), "context": "", "num_docs": 0, "expanded_queries": queries, "mode": "llm_fallback", "show_warning": True } print(f"📚 Retrieved {len(nodes)} documents") # Step 3: Post-Retrieval Filtering joined_query=" ".join(queries) nodes = self.post.rerank(nodes, joined_query, top_k=top_k) nodes = self.post.context_filter(nodes) context = self.post.compress_contexts(nodes) # Step 4: Generate answer = self.generator.generate(joined_query, context) # Step 5: Verification # The generate method handles internal verification based on self.generator.verify verified = self.generator.verify return { "question": question, "answer": answer, "verified": verified, "context": context, "num_docs": len(nodes), "expanded_queries": queries, "mode": "retrieval_augmented", "show_warning": False } # ----------------------------- # Prompt builder # ----------------------------- def build_yoruba_prompt(query: str, context: str = "", domain: str = "General") -> str: if context: prompt = f""" Ìwọ jẹ́ ọ̀jọ̀gbọ́n nínú {domain}. Fún ìdáhùn kan sí ìbéèrè yìí ní èdè Yorùbá nípa lílo àkíyèsí àwọn ìwé ìtọ́ni tí a fún ní ìsàlẹ̀. Àwọn ìwé ìtọ́ni (Context): {context} Ìbéèrè: {query} Jọwọ: 1. Dáhùn ní èdè Yorùbá 2. Bá ìbéèrè mu 3. Tó o jẹ́ òtító 4.Answer in clear Yoruba with proper paragraph spacing. Ìdáhùn: """ else: prompt = f""" Ìwọ jẹ́ ọ̀jọ̀gbọ́n nínú {domain}. Fún ìdáhùn kan sí ìbéèrè yìí ní èdè Yorùbá: Ìbéèrè: {query} Jọwọ: 1. Dáhùn ní èdè Yorùbá 2. Bá ìbéèrè mu 3. Tó o jẹ́ òtító 4. Answer in clear Yoruba with proper paragraph spacing. Ìdáhùn: """ return prompt # ----------------------------- # STARTUP: Chainlit handlers # ----------------------------- DOMAINS = ["Entertainment", "Current Affairs", "Social Life", "Culture", "Religion"] @cl.on_chat_start async def start(): """Load heavy resources lazily here. This reduces top-level startup time and avoids Docker startup timeouts. """ await cl.Message("🔧 Initializing application — this may take a few seconds...").send() # Prepare objects to store in user session state = {} # 1) Load embedding model lazily from llama_index.core import Settings try: embedder = AfriBERTaEmbedding() Settings.embed_model = embedder # Do not call embedder.load() synchronously to avoid long startup; but the first call will load it. state["embedder"] = Settings.embed_model except Exception as e: logger.warning(f"Embedding init failed: {e}") state["embedder"] = None # 2) Load or fallback LLM llm = None if GEMINI_API_KEY: Settings.llm = load_gemini_llm(GEMINI_API_KEY) else: logger.info("GEMINI_API_KEY not set. LLM will be None (fallback to safe messages).") # 3) Setup simple modules (no vector store unless weaviate configured) generation_module = GenerationModule(llm=Settings.llm, verify=True) # If weaviate is available, try to create a retriever retrieval_module = None try: if WEAVIATE_URL and WEAVIATE_API_KEY: import weaviate from llama_index.vector_stores.weaviate import WeaviateVectorStore from llama_index.core import StorageContext, VectorStoreIndex client = weaviate.connect_to_weaviate_cloud( cluster_url=WEAVIATE_URL, auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY), skip_init_checks=True, ) agent = QueryAgent(client=client, collections=[ QueryAgentCollectionConfig( name="Yoruba_rag", ), ], ) # Build minimal vector store wrapper — this assumes index already exists in Weaviate vector_store = WeaviateVectorStore(weaviate_client=client, index_name="YorubaChunk") storage_context = StorageContext.from_defaults(vector_store=vector_store) vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context, embed_model=embedder) #retriever = vector_index.as_retriever(similarity_top_k=5) retriever = WeaviateAgentRetriever(agent=agent) retrieval_module = RetrievalModule(dense_retriever=retriever,sparse_retriever=None) logger.info("Weaviate retriever initialized") else: logger.info("Weaviate not configured — continuing without a retriever.") except Exception as e: logger.warning(f"Weaviate initialization failed: {e}") retrieval_module = None # 4) Post module pre = PreRetrievalModule(llm=Settings.llm, enable_hyde=True) post_module = PostRetrievalModule(embed_model=state.get("embedder")) # 5) Orchestrator orchestrator = OrchestrationModule(pre=pre, retriever=retrieval_module, post=post_module, generator=generation_module) # Save into session cl.user_session.set("state", state) cl.user_session.set("orchestrator", orchestrator) cl.user_session.set("settings", {"domain": DOMAINS[0]}) # Send a friendly welcome with domain selector settings = await cl.ChatSettings( [ Select(id="domain", label="Select Domain (Ọ̀nà ìbéèrè)", values=DOMAINS, initial_index=0), ] ).send() # Welcome message in both Yoruba and English welcome_message = """ ## 🇳🇬 Pẹ̀lẹ́ o! | Welcome to the Yorùbá Question Answering Assistant This assistant helps you ask and receive answers in **Yorùbá**, powered by advanced artificial intelligence and curated knowledge sources. Please help us improve by evaluating this response. [👉 Click here to fill the form](https://forms.gle/owiYWgNgoeLtjr3N7) --- ### 📘 Bí o ṣe lè lo ìrànlọ́wọ́ yìí | How to Use This Assistant 1. 📁 **Yan apakan (Domain)** — Select a domain using the settings (⚙️). 2. 💬 **Beere ìbéèrè rẹ** — Ask your question in Yorùbá (or English if needed). 3. 🔍 **Gba ìdáhùn tó péye** — The system retrieves relevant information and generates a clear response. --- ### 🌍 Àwọn Apakan tí ó wà | Available Domains - 🎬 **Entertainment** — Movies, music, sports, and popular culture - 📰 **Current Affairs** — News, politics, economy - 👥 **Social Life** — Relationships, community, etiquette - 🎭 **Culture** — Traditions, history, festivals - 🙏 **Religion** — Beliefs, practices, spirituality --- ### ✅ Bẹrẹ nípa yíyan apakan kan, lẹ́yìn náà beere ìbéèrè rẹ ### Start by selecting a domain and asking your question """ await cl.Message(content=welcome_message).send() # ============================================================ @cl.on_settings_update async def load_thread(settings): pass # not needed yet @cl.on_message async def main(message: cl.Message): question = message.content.strip() if not question: await cl.Message(content="⚠️ Jọwọ beere ìbéèrè kan — ask a question.").send() return orchestrator: OrchestrationModule = cl.user_session.get("orchestrator") settings = cl.user_session.get("settings") or {"domain": DOMAINS[0]} domain = settings.get("domain", DOMAINS[0]) # Send an initial typing message and stream tokens msg = await cl.Message(content="🔄 Processing your question...").send() # If orchestrator not available, reply with fallback if orchestrator is None: await msg.update(content="⚠️ System not initialized. Try again shortly.") return # Run query (synchronous call inside async — if heavy, consider running in threadpool) try: result = await asyncio.to_thread(orchestrator.modular_query, question, domain, 20) except Exception as e: logger.exception("Query failed") await msg.update(content=f"⚠️ Query failed: {e}") return answer = result.get("answer", "⚠️ Ko ṣee ṣe lati gba ìdáhùn.") show_warming=result.get("show_warning", False) # Stream simple token chunks (word by word) import re tokens= re.findall(r'\S+\s*', answer) for token in tokens: await msg.stream_token(token + " ") await msg.update(content=answer) if show_warning: await cl.Message( content="⚠️ Ìkìlọ̀: Ìdáhùn yìí dá lórí ìmọ LLM nìkan (kò sí ìwé ìtọ́ni tí a lo)." ).send() # Send metadata mode = result.get("mode", "unknown") num_docs = result.get("num_docs", 0) if mode == "retrieval_augmented": await cl.Message(content=f"📚 Used {num_docs} documents for context.", author="system").send() else: await cl.Message(content="ℹ️ Answer produced from LLM fallback (no retrieved docs).").send() # ----------------------------- # If run as script, expose entrypoint name (useful for local testing) # ----------------------------- if __name__ == "__main__": print("This file is intended to be run with: chainlit run fixed_chainlit_app.py")