# study-sathi / chain/qa_chain.py — updated qa_chain (commit b063f48)
# chain/qa_chain.py
import os
import re
from dotenv import load_dotenv
from openai import OpenAI
from rag.retriever import retrieve, format_context
load_dotenv()  # pull OPENROUTER_API_KEY (and friends) from a local .env file

# ── Config ───────────────────────────────────────────────────────────────────
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")  # None if unset; client calls will fail without it
MODEL = "openai/gpt-oss-120b"  # OpenRouter model slug
MAX_TOKENS = 2048  # cap on completion length per reply
MAX_HISTORY = 10 # keep last 10 exchanges to avoid token overflow

# ── OpenRouter Client ─────────────────────────────────────────────────────────
# OpenRouter exposes an OpenAI-compatible API, so the official OpenAI client
# works when base_url is pointed at openrouter.ai.
client = OpenAI(
    api_key = OPENROUTER_API_KEY,
    base_url = "https://openrouter.ai/api/v1"
)
# ── System Prompt ─────────────────────────────────────────────────────────────
# Fixes vs. previous revision: broken bullet marker ("_" → "-"), typos
# ("DONOT" → "DO NOT", "EXPLIANING" → "EXPLAINING"), and a mis-encoded em dash.
SYSTEM_PROMPT = """You are Study Saathi — a friendly and smart study assistant.
You help students understand their Operating Systems notes.
Rules you must follow:
- Answer ONLY from the provided context or conversation history. Never use outside knowledge.
- If the answer is not in the context or history, say: "Yeh topic notes mein nahi mila."
- Whenever you are asked to explain the topic it means you have to explain the topic in a simple way. You can use examples, analogies, and simple language to make it easy to understand. DO NOT USE BULLET POINTS WHENEVER EXPLAINING THE TOPIC. USE PLAIN TEXT TO EXPLAIN THE TOPIC.
- You have access to the full conversation history. Use it to answer follow-up
questions like "translate the above", "explain more", "give examples of that", etc.
- Explain in simple Roman Urdu, Urdu, or English — based on what the user uses.
- Use markdown formatting: headings, bullet points, bold, tables where helpful.
- Keep explanations clear and student-friendly.
- For MCQs: generate exactly the number asked, only from the provided context.
- For MCQ answer keys: always return them in a markdown table.
"""
# ── Follow-up detection ───────────────────────────────────────────────────────
# Regex cues (English + Roman Urdu) suggesting the user refers back to an
# earlier answer rather than asking a brand-new question.
# NOTE(review): broad words like \bthis\b / \bthat\b may over-trigger on new
# questions — confirm against real queries.
FOLLOWUP_PATTERNS = [
    r'\babove\b', r'\bupar\b', r'\bwoh\b', r'\bise\b', r'\busse\b',
    r'\btranslate\b', r'\btarjuma\b', r'\bsummari[sz]e\b',
    r'\bexplain more\b', r'\baur explain\b', r'\baur batao\b',
    r'\bexpand\b', r'\brepeat\b', r'\bdobara\b', r'\bphir se\b',
    r'\bsimplify\b', r'\brewrite\b', r'\bconvert\b',
    r'\bthis\b', r'\bthat\b', r'\byeh\b', r'\bwahi\b',
    r'\bprevious\b', r'\blast\b',
]


def is_followup(query: str) -> bool:
    """Return True when *query* looks like a follow-up that needs no RAG."""
    lowered = query.lower()
    for pattern in FOLLOWUP_PATTERNS:
        if re.search(pattern, lowered):
            return True
    return False
# ── Build Context Prompt ──────────────────────────────────────────────────────
def build_context_prompt(query: str, context: str) -> str:
    """Wrap retrieved note chunks and the student's question into one user message."""
    parts = [
        "Use the following context from the student's notes to help answer.",
        "--- CONTEXT START ---",
        context,
        "--- CONTEXT END ---",
        f"Student's Question: {query}",
        "",  # keep the trailing newline of the original template
    ]
    return "\n".join(parts)
def build_followup_prompt(query: str) -> str:
    """Build a user message instructing the model to answer from history alone."""
    header = "This is a follow-up question. Use the conversation history above to answer."
    rule = "Do NOT search for new context β€” work only from what was already discussed."
    return f"{header}\n{rule}\nStudent's Follow-up: {query}\n"
# ── Detect MCQ Request ────────────────────────────────────────────────────────
def extract_mcq_count(query: str):
    """Return the number of MCQs/questions requested in *query*, or None.

    Matches forms like "5 mcqs", "10 questions" (case-insensitive). The
    trailing word boundary prevents false hits such as "5 questionnaires",
    which the previous alternation (mcq|question|mcqs|questions) matched.
    """
    match = re.search(r'(\d+)\s*(?:mcqs?|questions?)\b', query.lower())
    return int(match.group(1)) if match else None
# ── Trim history to avoid token overflow ─────────────────────────────────────
def trim_history(history: list) -> list:
    """Return at most the newest MAX_HISTORY exchanges (2 messages each)."""
    limit = 2 * MAX_HISTORY
    # Short histories are returned unchanged (same object, as before).
    return history if len(history) <= limit else history[-limit:]
# ── Main Chain ────────────────────────────────────────────────────────────────
def run_chain(query: str, topic: str = None, history: list = None) -> str:
    """
    Full RAG chain with conversation memory.

    - Follow-up questions skip RAG and use history only.
    - New questions retrieve chunks from Pinecone.

    Args:
        query: the student's current question.
        topic: optional topic filter forwarded to retrieve().
        history: prior chat messages as {"role": ..., "content": ...} dicts.
            The last entry is assumed to be the current user message and is
            dropped before sending (the rebuilt prompt replaces it).

    Returns:
        The assistant's reply text.
    """
    # Fix: default was a mutable `[]` (shared across calls); use None sentinel.
    if history is None:
        history = []

    mcq_count = extract_mcq_count(query)
    followup = is_followup(query) and len(history) > 0

    if followup:
        # ── Follow-up: skip RAG, use history only ──────────────────────────
        user_message = build_followup_prompt(query)
    else:
        # ── New question: retrieve from Pinecone ───────────────────────────
        # MCQ generation needs more source material, hence the larger top_k.
        top_k = 10 if mcq_count else 5
        chunks = retrieve(query, topic=topic, top_k=top_k)
        if not chunks:
            context = "No relevant context found in the notes."
        else:
            context = format_context(chunks)
        user_message = build_context_prompt(query, context)

    # inject MCQ instructions if needed (applies to either branch)
    if mcq_count:
        user_message += (
            f"\n\nIMPORTANT: Generate exactly {mcq_count} MCQs from the context above. "
            "Format each MCQ as:\n**Q1.** Question\n- A) option\n- B) option\n"
            "- C) option\n- D) option\n\n"
            "After all MCQs, provide the answer key in a markdown table "
            "with columns: | Q# | Answer | Explanation |"
        )

    # ── Build messages: system + trimmed history + current user msg ────────
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages += trim_history(history[:-1])  # history minus the last user msg
    messages.append({"role": "user", "content": user_message})

    response = client.chat.completions.create(
        model = MODEL,
        messages = messages,
        max_tokens = MAX_TOKENS,
    )
    return response.choices[0].message.content
# ── Quick Test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: requires OPENROUTER_API_KEY and Pinecone access (network I/O).
    q1 = "Explain Process Registers in simple words"
    r1 = run_chain(q1, topic="ch-01-updated", history=[])
    print("=== Response 1 ===\n", r1)

    # Mimic the UI's convention: history already ends with the new user
    # message, which run_chain strips (history[:-1]) before sending.
    fake_history = [
        {"role": "user", "content": q1},
        {"role": "assistant", "content": r1},
        {"role": "user", "content": "Translate the above into Roman Urdu"},
    ]
    q2 = "Translate the above into Roman Urdu"
    r2 = run_chain(q2, topic="ch-01-updated", history=fake_history)
    print("\n=== Response 2 (follow-up) ===\n", r2)