IJNet-assistant / src /chain.py
Mohammad Haris
Deploy IJNet assistant
b87aca1
Raw
History Blame Contribute Delete
19.1 kB
"""
RAG Chain
----------
Connects the hybrid retriever to Groq's LLM with a carefully
engineered prompt for accurate, source-cited responses.
Improvements over v1:
- Streaming responses for real-time token output
- Input guardrails to reject off-topic queries
- Conversation memory summarization for long chats
- Robust error handling with retries and fallbacks
"""
import os
import re
import time
from typing import Optional, Generator
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from src.retriever import HybridRetriever
# ---------------------------------------------------------------------------
# SYSTEM PROMPT
# ---------------------------------------------------------------------------
# System prompt for every answer. It is passed to ChatPromptTemplate (see
# IJNetRAGChain.__init__), which fills {current_date}; any other literal brace
# in this text would be parsed as a template variable.
# NOTE: the "no URLs" rule explicitly exempts the ijnet.org references the
# prompt itself asks for, so the two instructions no longer contradict.
SYSTEM_PROMPT = """You are the IJNet Assistant — a helpful, knowledgeable chatbot for the International Journalists' Network (IJNet). IJNet connects journalists worldwide with training opportunities, fellowships, grants, awards, tools, and expert guidance.
YOUR ROLE:
- Help journalists find relevant opportunities, resources, and information from IJNet's knowledge base.
- Always ground your answers in the provided context. Do NOT make up opportunities, deadlines, or details.
- Cite your sources clearly by referencing the opportunity/article title and the organizing body.
- Do NOT include any URLs or links in your response, other than the ijnet.org references these instructions explicitly call for. Sources are shown separately in the UI.
- If the context doesn't contain enough information to answer, say so honestly and suggest the user visit ijnet.org for the latest information.
RESPONSE GUIDELINES:
- Be concise but thorough. Journalists are busy — get to the point.
- When listing opportunities, include: title, organization, deadline, key benefits, and eligibility highlights.
- When discussing articles/resources, summarize the key takeaways.
- For deadline queries, clearly state which opportunities are still open and their exact deadlines.
- If asked about topics not in the context, say "I don't have information about that in my current knowledge base. I recommend checking ijnet.org for the latest opportunities and resources."
- Use a friendly, professional tone appropriate for an international audience.
- At the end, remind users to visit https://ijnet.org/en/opportunities for the most up-to-date listings.
FORMATTING:
- Use bullet points for listing multiple opportunities.
- Bold key details like deadlines and opportunity names.
- Keep responses focused and scannable.
Today's date is: {current_date}
"""
# Prompt used to compress older turns into a short summary; filled with
# str.format(conversation=...) when chat history grows too long.
SUMMARY_PROMPT = """Summarize this conversation between a user and the IJNet Assistant in 2-3 sentences.
Focus on what topics were discussed, what the user was looking for, and key information provided.
Conversation:
{conversation}
Summary:"""
# ---------------------------------------------------------------------------
# GUARDRAILS
# ---------------------------------------------------------------------------
# Topics the IJNet assistant is meant to handle. check_guardrails() allows
# every query that does not match an off-topic pattern, so this list does not
# currently change any decision; it documents the intended scope (it may be
# referenced elsewhere — verify before removing).
ALLOWED_TOPICS = [
    "journalism", "journalist", "media", "newsroom", "reporting",
    "fellowship", "grant", "award", "training", "opportunity",
    "ijnet", "icfj", "newsletter", "press", "editor",
    "investigation", "data journalism", "fact-check", "verification",
    "digital security", "ai tools", "mobile journalism",
    "freelance", "climate", "environment", "solutions journalism",
    "product design", "news product", "innovation",
    "africa", "asia", "europe", "latin america", "middle east", "mena",
    "deadline", "apply", "eligibility", "subscribe",
    "hello", "hi", "hey", "help", "thanks", "thank you", "what can you do",
]
OFF_TOPIC_RESPONSE = (
    "I'm the IJNet Assistant, and I'm specifically designed to help with "
    "journalism-related queries — like finding fellowships, grants, training "
    "programs, and resources for journalists. I can't help with that particular "
    "question, but I'd love to help you find journalism opportunities! "
    "Try asking something like:\n\n"
    "- *What fellowships are available for journalists in Africa?*\n"
    "- *What AI tools can journalists use?*\n"
    "- *Which IJNet newsletter should I subscribe to?*"
)

# Patterns for clearly off-topic requests, matched against the lower-cased
# query. Short single-word cues are anchored with \b so they do not fire
# inside ordinary words: unanchored, "aftermath" hit "math", "display " hit
# "play ", "riddled" hit "riddle" and "cookies" hit "cook", rejecting
# legitimate journalism questions. Compiled once at import time.
_OFF_TOPIC_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"write\s+(me\s+)?(a\s+)?(poem|song|story|essay|code|script)",
        r"weather\s+(in|for|at|today|tomorrow|forecast)",
        r"what.{0,5}s\s+the\s+weather",
        r"(stock|price|score)\s+(of|for|today)",
        r"\b(cook(?!ie)|recipe|ingredient|bake)",
        r"\b(math|calcul|equation|solve\s+(this|the|for))",
        r"translate\s",
        r"translation\s",
        r"\b(play|game|quiz|trivia)\s",
        r"tell\s+(me\s+)?a\s+joke",
        r"\b(jokes?|funny|humor|riddles?)\b",
    )
)


def check_guardrails(query: str) -> tuple[bool, str]:
    """
    Check if the query is within scope for the IJNet assistant.

    Order: very short queries (fewer than 4 characters after stripping,
    e.g. "hi") are always allowed; otherwise the off-topic patterns are
    checked, and anything that matches none of them is allowed ("better to
    answer than to wrongly reject"). Because the off-topic check is the only
    rejecting step, "translate hello world" is blocked even though it
    contains the greeting "hello".

    Args:
        query: The raw user question.

    Returns:
        (is_allowed, message) — message is OFF_TOPIC_RESPONSE when the query
        is rejected and "" when it is allowed.
    """
    q_lower = query.lower().strip()

    # Greetings and other tiny inputs never need filtering.
    if len(q_lower) < 4:
        return True, ""

    if any(pattern.search(q_lower) for pattern in _OFF_TOPIC_PATTERNS):
        return False, OFF_TOPIC_RESPONSE

    # Journalism keywords, short follow-ups and everything else are all
    # allowed, so no further checks are needed.
    return True, ""
# ---------------------------------------------------------------------------
# CONTEXT FORMATTER
# ---------------------------------------------------------------------------
def format_context(documents: list[Document]) -> str:
    """
    Format retrieved documents into a structured context string for the LLM.

    Each document becomes a header line
    ``[Source i] <title> | <Field>: <value> | ...`` followed by its
    ``page_content``; documents are separated by a ``---`` rule.
    Opportunity headers list type, organization (only when present),
    deadline and regions; article headers list author and date. Metadata
    values that are missing, None or "" render as ``N/A`` (``Untitled`` for
    the title) instead of the literal text "None" or a blank.

    Args:
        documents: Retrieved documents, in ranking order.

    Returns:
        The context string, or a fixed "no documents" notice when
        ``documents`` is empty.
    """
    if not documents:
        return "No relevant documents found in the knowledge base."

    def _field(meta: dict, key: str, default: str = "N/A"):
        """Return meta[key], or ``default`` when it is missing, None or ""."""
        value = meta.get(key)
        return default if value is None or value == "" else value

    blocks = []
    for index, doc in enumerate(documents, 1):
        meta = doc.metadata
        source_type = meta.get("source_type", "unknown")

        header_parts = [f"[Source {index}] {_field(meta, 'title', 'Untitled')}"]
        if source_type == "opportunity":
            header_parts.append(f"Type: {_field(meta, 'opp_type')}")
            # The system prompt asks the model to cite the organizing body,
            # so surface it whenever the metadata carries it.
            organization = meta.get("organization")
            if organization:
                header_parts.append(f"Organization: {organization}")
            header_parts.append(f"Deadline: {_field(meta, 'deadline')}")
            header_parts.append(f"Regions: {_field(meta, 'regions')}")
        elif source_type == "article":
            header_parts.append(f"Author: {_field(meta, 'author')}")
            header_parts.append(f"Date: {_field(meta, 'date')}")

        blocks.append(" | ".join(header_parts) + "\n" + doc.page_content)

    return "\n\n---\n\n".join(blocks)
def format_sources(documents: list[Document]) -> list[dict]:
    """
    Extract source metadata for display in the UI, one entry per document.

    Retrieval can return several chunks of the same document, so entries are
    de-duplicated by ``doc_id``. Documents without a ``doc_id`` fall back to
    their ``source`` URL, then their title; documents with none of these are
    always kept (otherwise they would all share an empty key and collapse
    into a single entry).

    Args:
        documents: Retrieved documents, in ranking order.

    Returns:
        A list of dicts with ``title``, ``url`` and ``type``, plus
        ``deadline``/``opp_type``/``organization`` for opportunities or
        ``author``/``date`` for articles. Order follows first appearance.
    """
    seen = set()
    sources = []
    for doc in documents:
        meta = doc.metadata

        dedupe_key = meta.get("doc_id") or meta.get("source") or meta.get("title")
        if dedupe_key:
            if dedupe_key in seen:
                continue
            seen.add(dedupe_key)

        source_type = meta.get("source_type")
        source = {
            "title": meta.get("title", "Unknown"),
            "url": meta.get("source", ""),
            "type": meta.get("source_type", ""),
        }
        if source_type == "opportunity":
            source["deadline"] = meta.get("deadline", "")
            source["opp_type"] = meta.get("opp_type", "")
            source["organization"] = meta.get("organization", "")
        elif source_type == "article":
            source["author"] = meta.get("author", "")
            source["date"] = meta.get("date", "")
        sources.append(source)
    return sources
# ---------------------------------------------------------------------------
# RAG CHAIN
# ---------------------------------------------------------------------------
class IJNetRAGChain:
    """
    End-to-end RAG chain: guardrails → retrieve → generate (stream) → cite.

    Supports streaming, multi-turn conversation with memory summarization,
    sidebar filters, and retry-based error handling.

    State: ``chat_history`` holds the most recent HumanMessage/AIMessage
    pairs and ``conversation_summary`` holds an LLM-written summary of the
    turns trimmed out of it. Both live on the instance, so one instance
    carries one conversation.
    """

    MAX_HISTORY_TURNS = 4  # Keep last N turn-pairs before summarizing
    MAX_RETRIES = 2  # Retry on transient errors
    RETRY_DELAY = 2  # Seconds; the n-th retry waits RETRY_DELAY * n

    # Lower-cased substrings of an LLM error message that mark an
    # authentication failure (reported to the user as such)...
    _AUTH_ERROR_MARKERS = ("api_key", "auth")
    # ...and errors that retrying cannot fix. Shared by query() and
    # query_stream() so both paths make the same retry decision.
    _NON_RETRYABLE_MARKERS = _AUTH_ERROR_MARKERS + ("invalid",)

    def __init__(
        self,
        retriever: HybridRetriever,
        groq_api_key: Optional[str] = None,
        model_name: str = "llama-3.3-70b-versatile",
        temperature: float = 0.1,
    ):
        """
        Args:
            retriever: Hybrid retriever used for every query.
            groq_api_key: Groq API key; falls back to the GROQ_API_KEY
                environment variable.
            model_name: Groq model identifier.
            temperature: Sampling temperature, used for answers and summaries.

        Raises:
            ValueError: If no API key is passed or found in the environment.
        """
        self.retriever = retriever
        api_key = groq_api_key or os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "Groq API key required. Set GROQ_API_KEY environment variable "
                "or pass groq_api_key parameter. Get a free key at https://console.groq.com"
            )
        self.llm = ChatGroq(
            model=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=1024,
        )
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "CONTEXT FROM IJNET KNOWLEDGE BASE:\n{context}\n\nUSER QUESTION: {question}"),
        ])
        self.chat_history: list = []  # Recent HumanMessage/AIMessage pairs
        self.conversation_summary: str = ""  # Summary of trimmed-out turns

    # ----- Memory Summarization -----
    def _summarize_history(self):
        """
        Fold turns beyond the last MAX_HISTORY_TURNS pairs into the summary.

        When ``chat_history`` holds more than ``MAX_HISTORY_TURNS * 2``
        messages, the older messages (each truncated to 200 characters) are
        sent to the LLM together with the existing summary, and the result
        replaces ``conversation_summary``. Including the existing summary
        keeps context from turns trimmed on earlier calls; without it every
        call would overwrite the summary with one covering only the most
        recently dropped pair.

        Best effort: if the LLM call fails, the previous summary is kept.
        The history is trimmed to the recent turns either way. Note this
        makes one extra LLM call on every turn once the limit is reached.
        """
        keep = self.MAX_HISTORY_TURNS * 2
        # Explicit split index (rather than [-keep:]) so keep == 0 still trims.
        split = len(self.chat_history) - keep
        if split <= 0:
            return
        older, recent = self.chat_history[:split], self.chat_history[split:]

        lines = []
        if self.conversation_summary:
            lines.append(f"Earlier summary: {self.conversation_summary}")
        for msg in older:
            role = "User" if isinstance(msg, HumanMessage) else "Assistant"
            lines.append(f"{role}: {msg.content[:200]}")
        conv_text = "\n".join(lines) + "\n"

        try:
            summary_response = self.llm.invoke(
                SUMMARY_PROMPT.format(conversation=conv_text)
            )
        except Exception:
            # A failed summary must not break the chat; keep the previous
            # summary rather than discarding it.
            pass
        else:
            self.conversation_summary = summary_response.content

        self.chat_history = recent

    def _get_effective_history(self) -> list:
        """Return chat history, preceded by a SystemMessage with the summary if one exists."""
        history = []
        if self.conversation_summary:
            history.append(SystemMessage(
                content=f"Summary of earlier conversation: {self.conversation_summary}"
            ))
        history.extend(self.chat_history)
        return history

    def _record_turn(self, question: str, answer: str) -> None:
        """Append a completed question/answer pair, then compact history if needed."""
        self.chat_history.append(HumanMessage(content=question))
        self.chat_history.append(AIMessage(content=answer))
        self._summarize_history()

    # ----- Error Handling -----
    @staticmethod
    def _error_has_marker(error: Exception, markers: tuple) -> bool:
        """Return True if the lower-cased error text contains any of ``markers``."""
        text = str(error).lower()
        return any(marker in text for marker in markers)

    def _invoke_with_retries(self, prompt_value) -> str:
        """
        Call the LLM and return the response text, retrying transient errors.

        Errors whose message contains a non-retryable marker are re-raised
        immediately. Other errors are retried up to MAX_RETRIES times with
        linear backoff, and the last one is raised if every attempt fails.
        """
        last_error: Optional[Exception] = None
        for attempt in range(self.MAX_RETRIES + 1):
            try:
                response = self.llm.invoke(prompt_value)
            except Exception as e:
                if self._error_has_marker(e, self._NON_RETRYABLE_MARKERS):
                    raise
                last_error = e
                if attempt < self.MAX_RETRIES:
                    time.sleep(self.RETRY_DELAY * (attempt + 1))
            else:
                return response.content
        raise last_error

    # ----- Core Query Methods -----
    def _retrieve_and_format(
        self, question: str, filters: Optional[dict] = None, include_debug: bool = False
    ) -> tuple[str, list[dict], Optional[dict]]:
        """
        Retrieve documents, apply optional sidebar filters, and format them.

        Returns:
            (context string for the prompt, source dicts for the UI, debug
            dict or None). The debug dict is only built when
            ``include_debug`` is True.
        """
        if include_debug:
            debug_info = self.retriever.retrieve_with_debug(question)
            retrieved_docs = debug_info["final_results"]
        else:
            retrieved_docs = self.retriever.retrieve(question)
            debug_info = None

        # Apply sidebar filters (post-retrieval boost)
        if filters:
            retrieved_docs = self._apply_ui_filters(retrieved_docs, filters)

        context = format_context(retrieved_docs)
        sources = format_sources(retrieved_docs)

        debug_out = None
        if include_debug and debug_info:
            debug_out = {
                "classification": debug_info["classification"],
                "num_retrieved": len(retrieved_docs),
                "semantic_top3": debug_info["semantic_results"][:3],
                "bm25_top3": debug_info["bm25_results"][:3],
            }
        return context, sources, debug_out

    def _apply_ui_filters(self, docs: list[Document], filters: dict) -> list[Document]:
        """
        Reorder documents so those matching the sidebar filters come first.

        This is a boost, not a filter: non-matching documents stay after the
        matches, and each group keeps its retrieval order. The region filter
        is a case-insensitive substring test on ``metadata["regions"]``; the
        type filter is a case-insensitive equality test on
        ``metadata["opp_type"]``. A value of "All" (or an empty value)
        disables a filter. When both are active, the type filter is applied
        last, so it decides the primary order.
        """
        ordered = list(docs)

        region = filters.get("region")
        if region and region != "All":
            wanted_region = region.lower()
            # list.sort is stable and False < True: matches move to the
            # front while both groups keep their relative order.
            ordered.sort(key=lambda d: wanted_region not in self._metadata_text(d, "regions"))

        opp_type = filters.get("opp_type")
        if opp_type and opp_type != "All":
            wanted_type = opp_type.lower()
            ordered.sort(key=lambda d: self._metadata_text(d, "opp_type") != wanted_type)

        return ordered

    @staticmethod
    def _metadata_text(doc: Document, key: str) -> str:
        """
        Return ``doc.metadata[key]`` as a lower-cased string ("" if missing, None or empty).

        Stringifying tolerates non-string values (e.g. if regions were stored
        as a list), on which ``.lower()`` would raise.
        """
        value = doc.metadata.get(key)
        return str(value).lower() if value else ""

    def _build_prompt_value(self, question: str, context: str):
        """Render the chat prompt with today's date, effective history, context and question."""
        from datetime import datetime

        return self.prompt.invoke({
            "current_date": datetime.now().strftime("%B %d, %Y"),
            "chat_history": self._get_effective_history(),
            "context": context,
            "question": question,
        })

    def query(
        self,
        question: str,
        filters: Optional[dict] = None,
        include_debug: bool = False,
    ) -> dict:
        """
        Non-streaming query. Returns the full response at once.

        Used as a fallback if streaming fails.

        Args:
            question: The user's question.
            filters: Optional sidebar filters (``region``, ``opp_type``).
            include_debug: Whether to attach retrieval debug info.

        Returns:
            ``{"answer": str, "sources": list[dict]}`` plus ``"debug"`` when
            debug info was produced. A guardrail rejection returns
            ``{"answer": <rejection text>, "sources": [], "guardrail_blocked": True}``.

        Raises:
            The LLM client's exception: immediately if its message looks
            non-retryable (mentions api_key/auth/invalid), otherwise after
            MAX_RETRIES retries.
        """
        # Guardrails check
        is_allowed, rejection_msg = check_guardrails(question)
        if not is_allowed:
            return {"answer": rejection_msg, "sources": [], "guardrail_blocked": True}

        context, sources, debug_out = self._retrieve_and_format(
            question, filters, include_debug
        )
        prompt_value = self._build_prompt_value(question, context)

        answer = self._invoke_with_retries(prompt_value)
        # Record history only after a successful call, outside the retry loop.
        self._record_turn(question, answer)

        result = {"answer": answer, "sources": sources}
        if debug_out:
            result["debug"] = debug_out
        return result

    def query_stream(
        self,
        question: str,
        filters: Optional[dict] = None,
        include_debug: bool = False,
    ) -> dict:
        """
        Streaming query.

        Returns a dict whose ``"answer_stream"`` is a generator yielding
        answer tokens, plus ``"sources"`` and, when produced, ``"debug"``.
        Retrieval runs eagerly in this call; the LLM call runs lazily as the
        generator is consumed. The turn is added to chat history only after
        the stream completes successfully.

        Errors are yielded as a final "❌ ..." text token instead of being
        raised. A failed attempt is retried only if it had not yielded any
        tokens yet: output already sent to the consumer cannot be taken
        back, and retrying would print a second answer after the partial one.

        Usage:
            result = chain.query_stream("...")
            for token in result["answer_stream"]:
                print(token, end="")
            # result["sources"] is available immediately, before streaming
        """
        # Guardrails check
        is_allowed, rejection_msg = check_guardrails(question)
        if not is_allowed:
            return {
                "answer_stream": iter([rejection_msg]),
                "sources": [],
                "guardrail_blocked": True,
            }

        context, sources, debug_out = self._retrieve_and_format(
            question, filters, include_debug
        )
        prompt_value = self._build_prompt_value(question, context)

        def token_generator() -> Generator[str, None, None]:
            last_error: Optional[Exception] = None
            for attempt in range(self.MAX_RETRIES + 1):
                full_response: list[str] = []
                try:
                    for chunk in self.llm.stream(prompt_value):
                        token = chunk.content
                        if token:
                            full_response.append(token)
                            yield token
                except Exception as e:
                    last_error = e
                    if self._error_has_marker(e, self._AUTH_ERROR_MARKERS):
                        yield "\n\n❌ Authentication error. Please check your API key."
                        return
                    if full_response or self._error_has_marker(e, self._NON_RETRYABLE_MARKERS):
                        # Partial output already reached the consumer, or the
                        # error is not transient: stop instead of retrying.
                        yield f"\n\n❌ Error: {e}"
                        return
                    if attempt < self.MAX_RETRIES:
                        time.sleep(self.RETRY_DELAY * (attempt + 1))
                else:
                    # Stream finished cleanly: record the completed turn.
                    self._record_turn(question, "".join(full_response))
                    return
            yield f"\n\n❌ Error after {self.MAX_RETRIES + 1} attempts: {last_error}"

        result = {
            "answer_stream": token_generator(),
            "sources": sources,
        }
        if debug_out:
            result["debug"] = debug_out
        return result

    def reset_history(self):
        """Clear conversation history and summary."""
        self.chat_history = []
        self.conversation_summary = ""