Spaces:
Sleeping
Sleeping
| """ | |
| RAG Chain | |
| ---------- | |
| Connects the hybrid retriever to Groq's LLM with a carefully | |
| engineered prompt for accurate, source-cited responses. | |
| Improvements over v1: | |
| - Streaming responses for real-time token output | |
| - Input guardrails to reject off-topic queries | |
| - Conversation memory summarization for long chats | |
| - Robust error handling with retries and fallbacks | |
| """ | |
| import os | |
| import re | |
| import time | |
| from typing import Optional, Generator | |
| from langchain_core.documents import Document | |
| from langchain_groq import ChatGroq | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage | |
| from src.retriever import HybridRetriever | |
# ---------------------------------------------------------------------------
# SYSTEM PROMPT
# ---------------------------------------------------------------------------
# Rendered through ChatPromptTemplate (f-string template format):
# {current_date} is the only template variable, so any literal brace added to
# this text must be doubled ("{{" / "}}").
SYSTEM_PROMPT = """You are the IJNet Assistant — a helpful, knowledgeable chatbot for the International Journalists' Network (IJNet). IJNet connects journalists worldwide with training opportunities, fellowships, grants, awards, tools, and expert guidance.
YOUR ROLE:
- Help journalists find relevant opportunities, resources, and information from IJNet's knowledge base.
- Always ground your answers in the provided context. Do NOT make up opportunities, deadlines, or details.
- Cite your sources clearly by referencing the opportunity/article title and the organizing body.
- Do NOT include URLs or links to individual sources in your response. Sources are shown separately in the UI. The only exception is referring users to ijnet.org, as instructed below.
- If the context doesn't contain enough information to answer, say so honestly and suggest the user visit ijnet.org for the latest information.
RESPONSE GUIDELINES:
- Be concise but thorough. Journalists are busy — get to the point.
- When listing opportunities, include: title, organization, deadline, key benefits, and eligibility highlights.
- When discussing articles/resources, summarize the key takeaways.
- For deadline queries, clearly state which opportunities are still open and their exact deadlines.
- If asked about topics not in the context, say "I don't have information about that in my current knowledge base. I recommend checking ijnet.org for the latest opportunities and resources."
- Use a friendly, professional tone appropriate for an international audience.
- At the end, remind users to visit https://ijnet.org/en/opportunities for the most up-to-date listings.
FORMATTING:
- Use bullet points for listing multiple opportunities.
- Bold key details like deadlines and opportunity names.
- Keep responses focused and scannable.
Today's date is: {current_date}
"""

# Filled with str.format(conversation=...) when older chat turns are compacted.
SUMMARY_PROMPT = """Summarize this conversation between a user and the IJNet Assistant in 2-3 sentences.
Focus on what topics were discussed, what the user was looking for, and key information provided.
Conversation:
{conversation}
Summary:"""
# ---------------------------------------------------------------------------
# GUARDRAILS
# ---------------------------------------------------------------------------
# Topics the IJNet assistant should handle. NOTE: check_guardrails below
# allows every query that is not explicitly off-topic, so this list is not
# consulted there (a keyword match could never change its result).
ALLOWED_TOPICS = [
    "journalism", "journalist", "media", "newsroom", "reporting",
    "fellowship", "grant", "award", "training", "opportunity",
    "ijnet", "icfj", "newsletter", "press", "editor",
    "investigation", "data journalism", "fact-check", "verification",
    "digital security", "ai tools", "mobile journalism",
    "freelance", "climate", "environment", "solutions journalism",
    "product design", "news product", "innovation",
    "africa", "asia", "europe", "latin america", "middle east", "mena",
    "deadline", "apply", "eligibility", "subscribe",
    "hello", "hi", "hey", "help", "thanks", "thank you", "what can you do",
]

OFF_TOPIC_RESPONSE = (
    "I'm the IJNet Assistant, and I'm specifically designed to help with "
    "journalism-related queries — like finding fellowships, grants, training "
    "programs, and resources for journalists. I can't help with that particular "
    "question, but I'd love to help you find journalism opportunities! "
    "Try asking something like:\n\n"
    "- *What fellowships are available for journalists in Africa?*\n"
    "- *What AI tools can journalists use?*\n"
    "- *Which IJNet newsletter should I subscribe to?*"
)

# Clearly off-topic request patterns, matched against the lowercased query.
# Word boundaries (\b) stop them firing inside unrelated words that show up in
# real journalism questions, e.g. "aftermath" (math), "cookies" (cook),
# "display" (play), "riddled" (riddle), "weather insurance" (weather in).
# Compiled once at import time instead of on every call.
_OFF_TOPIC_PATTERNS = [
    re.compile(pattern)
    for pattern in (
        r"\bwrite\s+(me\s+)?(a\s+)?(poem|song|story|essay|code|script)",
        r"\bweather\s+(in|for|at|today|tomorrow|forecast\w*)\b",
        r"\bwhat.{0,5}s\s+the\s+weather",
        r"\b(stock|price|score)\s+(of|for|today)\b",
        r"\b(cook(s|ed|ing|ery|book)?|recipes?|ingredients?|bak(e|es|ed|ing))\b",
        r"\b(maths?\b|mathemat|calcul|equations?\b|solve\s+(this|the|for)\b)",
        r"\btranslate\s",
        r"\btranslation\s",
        r"\b(play|game|quiz|trivia)\s",
        r"\btell\s+(me\s+)?a\s+joke",
        r"\b(jokes?|funny|humor\w*|riddles?)\b",
    )
]


def check_guardrails(query: str) -> tuple[bool, str]:
    """
    Decide whether a query is in scope for the IJNet assistant.

    Queries shorter than 4 characters (after stripping) are always allowed, so
    greetings like "hi" pass. Otherwise a query is rejected only when it matches
    one of the clearly off-topic patterns in ``_OFF_TOPIC_PATTERNS``. Everything
    else is allowed, because wrongly refusing a journalism question is worse
    than answering a stray off-topic one. Because the off-topic check is the
    only rejection rule, a greeting word never rescues an off-topic request
    such as "translate hello world".

    Args:
        query: Raw user input.

    Returns:
        ``(is_allowed, message)``: ``message`` is ``OFF_TOPIC_RESPONSE`` when
        the query is rejected and ``""`` otherwise.
    """
    q_lower = query.lower().strip()

    # Very short inputs (greetings, "ok", etc.) are never blocked.
    if len(q_lower) < 4:
        return True, ""

    if any(pattern.search(q_lower) for pattern in _OFF_TOPIC_PATTERNS):
        return False, OFF_TOPIC_RESPONSE

    # Default: allow (better to answer than to wrongly reject).
    return True, ""
# ---------------------------------------------------------------------------
# CONTEXT FORMATTER
# ---------------------------------------------------------------------------
def format_context(documents: "list[Document]") -> str:
    """
    Render retrieved documents as a numbered context block for the LLM.

    Each document becomes a header line of " | "-separated metadata followed by
    its page content, and documents are separated by "---" rules. Opportunity
    headers carry organization, type, deadline and regions. The system prompt
    asks the model to cite the organizing body, so the organization is included.
    Article headers carry author and date.

    Args:
        documents: Retrieved documents. ``metadata["source_type"]`` selects the
            extra header fields ("opportunity" or "article").

    Returns:
        The formatted context, or a placeholder sentence when ``documents`` is
        empty so the prompt never receives a blank context.
    """
    if not documents:
        return "No relevant documents found in the knowledge base."

    context_parts = []
    for i, doc in enumerate(documents, 1):
        meta = doc.metadata
        source_type = meta.get("source_type", "unknown")
        title = meta.get("title", "Untitled")
        header_parts = [f"[Source {i}] {title}"]
        if source_type == "opportunity":
            header_parts.append(f"Organization: {meta.get('organization', 'N/A')}")
            header_parts.append(f"Type: {meta.get('opp_type', 'N/A')}")
            header_parts.append(f"Deadline: {meta.get('deadline', 'N/A')}")
            header_parts.append(f"Regions: {meta.get('regions', 'N/A')}")
        elif source_type == "article":
            header_parts.append(f"Author: {meta.get('author', 'N/A')}")
            header_parts.append(f"Date: {meta.get('date', 'N/A')}")
        header = " | ".join(header_parts)
        context_parts.append(f"{header}\n{doc.page_content}")
    return "\n\n---\n\n".join(context_parts)
def format_sources(documents: "list[Document]") -> list[dict]:
    """
    Extract per-source metadata for display in the UI, one entry per source.

    Documents are de-duplicated by ``metadata["doc_id"]``, falling back to the
    source URL (``metadata["source"]``), and the first occurrence wins.
    Documents with neither identifier are never merged. Previously they all
    shared the key ``""``, so every such document after the first was
    silently dropped.

    Args:
        documents: Retrieved documents, in display order.

    Returns:
        A list of dicts with ``title``, ``url`` and ``type``. Opportunities add
        ``deadline``, ``opp_type`` and ``organization``; articles add ``author``
        and ``date``.
    """
    seen = set()
    sources = []
    for doc in documents:
        meta = doc.metadata
        doc_id = meta.get("doc_id")
        url = meta.get("source")
        # Tag the key kind so an id can never collide with a URL string.
        dedupe_key = ("id", doc_id) if doc_id else (("url", url) if url else None)
        if dedupe_key is not None:
            if dedupe_key in seen:
                continue
            seen.add(dedupe_key)

        source_type = meta.get("source_type", "")
        source = {
            "title": meta.get("title", "Unknown"),
            "url": meta.get("source", ""),
            "type": source_type,
        }
        if source_type == "opportunity":
            source["deadline"] = meta.get("deadline", "")
            source["opp_type"] = meta.get("opp_type", "")
            source["organization"] = meta.get("organization", "")
        elif source_type == "article":
            source["author"] = meta.get("author", "")
            source["date"] = meta.get("date", "")
        sources.append(source)
    return sources
# ---------------------------------------------------------------------------
# RAG CHAIN
# ---------------------------------------------------------------------------
class IJNetRAGChain:
    """
    End-to-end RAG chain: guardrails → retrieve → generate (stream) → cite.

    Supports streaming, multi-turn conversation with rolling memory
    summarization, sidebar filters, and retries on transient LLM errors.
    """

    MAX_HISTORY_TURNS = 4  # Turn-pairs kept verbatim; older turns are summarized
    MAX_RETRIES = 2  # Extra LLM attempts after the first failure
    RETRY_DELAY = 2  # Back-off base in seconds: retry k (1-based) sleeps RETRY_DELAY * k

    def __init__(
        self,
        retriever: "HybridRetriever",
        groq_api_key: Optional[str] = None,
        model_name: str = "llama-3.3-70b-versatile",
        temperature: float = 0.1,
        max_tokens: int = 1024,
    ):
        """
        Build the chain around a retriever and a Groq chat model.

        Args:
            retriever: Hybrid retriever used to fetch context documents.
            groq_api_key: Groq API key. Falls back to the GROQ_API_KEY
                environment variable when not given.
            model_name: Groq model identifier.
            temperature: Sampling temperature for generation.
            max_tokens: Cap on generated tokens per LLM call. It applies to
                both answers and history summaries.

        Raises:
            ValueError: If no API key is passed and GROQ_API_KEY is unset or empty.
        """
        self.retriever = retriever
        api_key = groq_api_key or os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "Groq API key required. Set GROQ_API_KEY environment variable "
                "or pass groq_api_key parameter. Get a free key at https://console.groq.com"
            )
        self.llm = ChatGroq(
            model=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "CONTEXT FROM IJNET KNOWLEDGE BASE:\n{context}\n\nUSER QUESTION: {question}"),
        ])
        self.chat_history: list = []
        self.conversation_summary: str = ""

    # ----- Memory Summarization -----
    def _summarize_history(self):
        """
        Fold turns older than the last MAX_HISTORY_TURNS pairs into
        ``conversation_summary`` and drop them from ``chat_history``.

        The previous summary is fed back into the summarization prompt. Without
        it, each compaction would replace the summary with one covering only
        the newly evicted turns, losing all earlier context. If the LLM call
        fails, the previous summary is kept and history is still truncated.
        """
        keep = self.MAX_HISTORY_TURNS * 2
        # Index-based split, so MAX_HISTORY_TURNS == 0 truncates to nothing
        # (slicing with [-0:] would keep the whole list instead).
        cut = len(self.chat_history) - keep
        if cut <= 0:
            return

        lines = []
        if self.conversation_summary:
            lines.append(f"(Summary of earlier turns) {self.conversation_summary}")
        for msg in self.chat_history[:cut]:
            role = "User" if isinstance(msg, HumanMessage) else "Assistant"
            # Truncate each message to keep the summarization prompt small.
            lines.append(f"{role}: {msg.content[:200]}")
        conv_text = "\n".join(lines) + "\n"

        try:
            summary_response = self.llm.invoke(
                SUMMARY_PROMPT.format(conversation=conv_text)
            )
            self.conversation_summary = summary_response.content
        except Exception:
            # Best effort: keep the previous summary rather than discarding it.
            pass

        self.chat_history = self.chat_history[cut:]

    def _get_effective_history(self) -> list:
        """Return chat history, prefixed by a SystemMessage carrying the summary if one exists."""
        history = []
        if self.conversation_summary:
            history.append(SystemMessage(
                content=f"Summary of earlier conversation: {self.conversation_summary}"
            ))
        history.extend(self.chat_history)
        return history

    # ----- Core Query Methods -----
    def _retrieve_and_format(
        self, question: str, filters: Optional[dict] = None, include_debug: bool = False
    ) -> tuple[str, list[dict], Optional[dict]]:
        """
        Retrieve documents, reorder them by optional UI filters, and format them.

        Returns:
            ``(context, sources, debug)``. ``debug`` is None unless
            ``include_debug`` is set and the retriever returned debug info.
        """
        if include_debug:
            debug_info = self.retriever.retrieve_with_debug(question)
            retrieved_docs = debug_info["final_results"]
        else:
            retrieved_docs = self.retriever.retrieve(question)
            debug_info = None

        # Apply sidebar filters (post-retrieval boost)
        if filters:
            retrieved_docs = self._apply_ui_filters(retrieved_docs, filters)

        context = format_context(retrieved_docs)
        sources = format_sources(retrieved_docs)

        debug_out = None
        if include_debug and debug_info:
            debug_out = {
                "classification": debug_info["classification"],
                "num_retrieved": len(retrieved_docs),
                "semantic_top3": debug_info["semantic_results"][:3],
                "bm25_top3": debug_info["bm25_results"][:3],
            }
        return context, sources, debug_out

    def _apply_ui_filters(self, docs: "list[Document]", filters: dict) -> "list[Document]":
        """
        Reorder documents so those matching the sidebar filters come first.

        Nothing is removed; non-matching documents stay after the matches as
        fallback context.
        - Region: case-insensitive substring test on ``metadata["regions"]``.
        - Opportunity type: case-insensitive exact match on ``metadata["opp_type"]``.

        A falsy value or "All" disables a filter. The opp-type ordering is
        applied last, so it dominates, and within each group the earlier order
        is preserved.
        """
        def meta_text(doc, key):
            # Missing or None metadata compares as empty text instead of crashing.
            return str(doc.metadata.get(key) or "").lower()

        ordered = list(docs)

        region = filters.get("region")
        if region and region != "All":
            needle = region.lower()
            # sorted() is stable: matches (key False) move to the front in their
            # original order. This is one pass, not an O(n^2) `not in` scan.
            ordered = sorted(ordered, key=lambda d: needle not in meta_text(d, "regions"))

        opp_type = filters.get("opp_type")
        if opp_type and opp_type != "All":
            wanted = opp_type.lower()
            ordered = sorted(ordered, key=lambda d: meta_text(d, "opp_type") != wanted)

        return ordered

    def _build_prompt_value(self, question: str, context: str):
        """Fill the chat prompt with today's date, effective history, context and question."""
        from datetime import datetime
        return self.prompt.invoke({
            "current_date": datetime.now().strftime("%B %d, %Y"),
            "chat_history": self._get_effective_history(),
            "context": context,
            "question": question,
        })

    def query(
        self,
        question: str,
        filters: Optional[dict] = None,
        include_debug: bool = False,
    ) -> dict:
        """
        Non-streaming query. Returns full response at once.
        Used as fallback if streaming fails.

        Returns:
            ``{"answer", "sources"}`` plus ``"debug"`` when requested. When the
            guardrail rejects the question, ``guardrail_blocked`` is True.

        Raises:
            Exception: The LLM error, immediately for errors that look like
                auth/validation failures, otherwise after all retries fail.
        """
        # Guardrails check
        is_allowed, rejection_msg = check_guardrails(question)
        if not is_allowed:
            return {"answer": rejection_msg, "sources": [], "guardrail_blocked": True}

        context, sources, debug_out = self._retrieve_and_format(
            question, filters, include_debug
        )
        prompt_value = self._build_prompt_value(question, context)

        last_error = None
        for attempt in range(self.MAX_RETRIES + 1):
            # Only the LLM call is retried. History bookkeeping stays outside
            # the try, so a failure there can't trigger a second call that
            # appends the turn twice.
            try:
                response = self.llm.invoke(prompt_value)
            except Exception as e:
                last_error = e
                error_msg = str(e).lower()
                # Don't retry on auth / invalid-request errors
                if "api_key" in error_msg or "auth" in error_msg or "invalid" in error_msg:
                    raise
                if attempt < self.MAX_RETRIES:
                    time.sleep(self.RETRY_DELAY * (attempt + 1))
                continue

            answer = response.content
            self.chat_history.append(HumanMessage(content=question))
            self.chat_history.append(AIMessage(content=answer))
            self._summarize_history()

            result = {"answer": answer, "sources": sources}
            if debug_out:
                result["debug"] = debug_out
            return result

        raise last_error

    def query_stream(
        self,
        question: str,
        filters: Optional[dict] = None,
        include_debug: bool = False,
    ) -> dict:
        """
        Streaming query. Returns a dict whose 'answer_stream' entry is a
        generator of response tokens, plus 'sources' and, when requested,
        'debug'. Retrieval runs eagerly, so 'sources' is usable before the
        stream is consumed. Generation starts on the first iteration.

        Usage:
            result = chain.query_stream("...")
            for token in result["answer_stream"]:
                print(token, end="")

        Errors are reported in-band as a final "❌ ..." token rather than
        raised. A failed attempt is retried only if it produced no tokens.
        Once text has reached the consumer, a retry would repeat it, so the
        stream ends with an error marker instead. Chat history is updated only
        after a complete, successful stream. A guardrail rejection yields the
        rejection text and sets 'guardrail_blocked'.
        """
        # Guardrails check
        is_allowed, rejection_msg = check_guardrails(question)
        if not is_allowed:
            return {
                "answer_stream": iter([rejection_msg]),
                "sources": [],
                "guardrail_blocked": True,
            }

        context, sources, debug_out = self._retrieve_and_format(
            question, filters, include_debug
        )
        prompt_value = self._build_prompt_value(question, context)

        def token_generator() -> Generator[str, None, None]:
            last_error = None
            for attempt in range(self.MAX_RETRIES + 1):
                full_response = []
                try:
                    for chunk in self.llm.stream(prompt_value):
                        token = chunk.content
                        if token:
                            full_response.append(token)
                            yield token
                except Exception as e:
                    last_error = e
                    error_msg = str(e).lower()
                    if "api_key" in error_msg or "auth" in error_msg:
                        yield "\n\n❌ Authentication error. Please check your API key."
                        return
                    if full_response:
                        # Tokens were already delivered; restarting would
                        # duplicate them, so stop with an error marker.
                        yield f"\n\n❌ Response interrupted: {last_error}"
                        return
                    if attempt < self.MAX_RETRIES:
                        time.sleep(self.RETRY_DELAY * (attempt + 1))
                        continue
                    yield f"\n\n❌ Error after {self.MAX_RETRIES + 1} attempts: {last_error}"
                    return

                # Stream completed successfully: record the turn.
                answer = "".join(full_response)
                self.chat_history.append(HumanMessage(content=question))
                self.chat_history.append(AIMessage(content=answer))
                self._summarize_history()
                return

        result = {
            "answer_stream": token_generator(),
            "sources": sources,
        }
        if debug_out:
            result["debug"] = debug_out
        return result

    def reset_history(self):
        """Clear conversation history and summary."""
        self.chat_history = []
        self.conversation_summary = ""