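"""Retrieval and guardrail utilities for Krish Mind, the KRCE campus assistant.

Builds a hybrid dense + lexical index over the KRCE JSONL knowledge base,
routes queries (KRCE / general / hybrid), assembles system prompts, and
applies lightweight post-generation checks against hallucinated identity
claims.
"""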
from __future__ import annotations
import json
import math
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Any
import numpy as np
from sentence_transformers import SentenceTransformer
DEFAULT_DATA_FILE = os.path.join(os.path.dirname(__file__), "data", "krce_college_data.jsonl")
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
ABSTAIN_MESSAGE = "I don't know from the KRCE knowledge base."
# Keep this simple: only a minimal relevance threshold.
MIN_CONFIDENCE = 0.25
TOP_K = 3
SEARCH_STOPWORDS = {
"a", "an", "and", "are", "at", "be", "for", "from", "how", "in", "is", "it", "of", "on", "or",
"the", "to", "what", "when", "where", "who", "with", "your", "please", "tell", "me", "about",
}
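# The marker tuples below are matched as plain substrings of normalized text,
# not as whole words, so keep entries specific enough to avoid accidental hits.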
# Lightweight post-generation safety net.
HALLUCINATION_MARKERS = (
"created by",
"created independently",
"created after leaving",
"des created me",
"i was created",
"krish cs my creator",
"my creator",
"my founder",
)
GENERAL_KNOWLEDGE_MARKERS = (
"algorithm",
"array",
"binary tree",
"coding",
"computer science",
"data structure",
"debug",
"explain",
"merge sort",
"python",
"quick sort",
"sorting",
"stack",
)
LIST_QUERY_MARKERS = (
"all",
"boys",
"faculty",
"faculties",
"girls",
"list",
"members",
"restroom",
"restrooms",
"staff",
"staffs",
"washroom",
"washrooms",
"who are",
)
TRAILING_QUERY_NOISE_MARKERS = (
", tell me about ",
", who are ",
", who is ",
", how many ",
", i m a cse student",
", i am a cse student",
", is dr ",
", krce cse",
", my hod if",
)
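# Greedy match for honorific-prefixed names; _extract_people_names() trims the
# role/department tails that this character class also swallows.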
NAME_PATTERN = re.compile(r"\b(?:Dr|Mr|Mrs|Ms)\.\s*[A-Za-z][A-Za-z\s.]{1,70}")
@dataclass(frozen=True)
class RagIndex:
model: SentenceTransformer | None
records: list[dict[str, str]]
documents: list[str]
embeddings: np.ndarray | None
tokenized_documents: list[list[str]]
idf: dict[str, float]
def normalize_text(text: str) -> str:
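    """Lowercase and strip punctuation for fuzzy matching.

    e.g. normalize_text("Who is the HOD of CSE/IT?") -> "who is the hod of cse it"
    """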
text = text.lower().replace("'", " ").replace("/", " ").replace("-", " ")
text = re.sub(r"[^a-z0-9\s.]+", " ", text)
text = text.replace(".", " ")
return re.sub(r"\s+", " ", text).strip()
def _tokenize_for_search(text: str) -> list[str]:
    normalized = normalize_text(text)
    # str.split() never yields empty tokens, so only stopwords need filtering.
    return [token for token in normalized.split() if token not in SEARCH_STOPWORDS]
def _build_idf(tokenized_documents: list[list[str]]) -> dict[str, float]:
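    """Smoothed inverse document frequency: idf(t) = log((N + 1) / (df(t) + 1)) + 1."""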
if not tokenized_documents:
return {}
doc_freq: dict[str, int] = {}
total_docs = len(tokenized_documents)
for tokens in tokenized_documents:
unique_tokens = set(tokens)
for token in unique_tokens:
doc_freq[token] = doc_freq.get(token, 0) + 1
idf: dict[str, float] = {}
for token, freq in doc_freq.items():
idf[token] = math.log((total_docs + 1.0) / (freq + 1.0)) + 1.0
return idf
def _lexical_score(query_tokens: list[str], doc_tokens: list[str], idf: dict[str, float]) -> float:
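    """IDF-weighted fraction of query tokens that appear in the document, in [0, 1]."""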
if not query_tokens or not doc_tokens:
return 0.0
doc_set = set(doc_tokens)
weighted_overlap = sum(idf.get(token, 1.0) for token in query_tokens if token in doc_set)
weighted_total = sum(idf.get(token, 1.0) for token in query_tokens)
if weighted_total <= 0:
return 0.0
return weighted_overlap / weighted_total
def _clean_output_text(output: str) -> str:
cleaned = output.strip()
lowered = cleaned.lower()
cut_positions = []
for marker in TRAILING_QUERY_NOISE_MARKERS:
pos = lowered.find(marker)
if pos != -1:
cut_positions.append(pos)
if cut_positions:
cleaned = cleaned[: min(cut_positions)].rstrip(" ,;")
return cleaned
def is_krce_scope_query(query: str) -> bool:
normalized = normalize_text(query)
# Minimal scope check to decide when to force abstain on low confidence.
krce_terms = (
"krce",
"k ramakrishnan",
"college",
"department",
"faculty",
"hod",
"principal",
"professor",
"cse",
"ece",
"eee",
"ai ds",
"aids",
"csbs",
)
return any(term in normalized for term in krce_terms)
def classify_query_route(query: str) -> str:
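    """Route a query to "krce", "general", or "hybrid" handling.

    e.g. "explain merge sort" -> "general",
         "who is the cse hod at krce" -> "krce",
         "explain python to a krce cse student" -> "hybrid"
    """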
normalized = normalize_text(query)
krce_scope = is_krce_scope_query(query)
general_scope = any(marker in normalized for marker in GENERAL_KNOWLEDGE_MARKERS)
if krce_scope and general_scope:
return "hybrid"
if krce_scope:
return "krce"
return "general"
def _load_records(data_file: str) -> list[dict[str, str]]:
records: list[dict[str, str]] = []
with open(data_file, "r", encoding="utf-8") as handle:
for line in handle:
if not line.strip():
continue
try:
item = json.loads(line)
except json.JSONDecodeError:
continue
instruction = str(item.get("instruction", "")).strip()
output = _clean_output_text(str(item.get("output", "")))
if not instruction and not output:
continue
records.append(
{
"instruction": instruction,
"output": output,
}
)
return records
@lru_cache(maxsize=2)
def load_rag_index(data_file: str = DEFAULT_DATA_FILE, embedding_model: str = DEFAULT_EMBEDDING_MODEL) -> RagIndex:
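    """Load records, embeddings, and IDF stats into a memoized RagIndex.

    Returns an empty index (model=None) when the data file or the embedding
    model is unavailable, so callers degrade to abstaining instead of crashing.
    """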
if not os.path.exists(data_file):
return RagIndex(model=None, records=[], documents=[], embeddings=None, tokenized_documents=[], idf={})
try:
model = SentenceTransformer(embedding_model)
except Exception:
return RagIndex(model=None, records=[], documents=[], embeddings=None, tokenized_documents=[], idf={})
records = _load_records(data_file)
documents = [f"{record['instruction']}\n{record['output']}".strip() for record in records]
if documents:
embeddings = model.encode(documents, normalize_embeddings=True, convert_to_numpy=True)
else:
embeddings = np.empty((0, 0), dtype=np.float32)
tokenized_documents = [_tokenize_for_search(doc) for doc in documents]
idf = _build_idf(tokenized_documents)
return RagIndex(
model=model,
records=records,
documents=documents,
embeddings=embeddings,
tokenized_documents=tokenized_documents,
idf=idf,
)
def search_krce(query: str, rag_index: RagIndex, top_k: int = TOP_K) -> dict[str, Any]:
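    """Rank knowledge-base records for `query` with a dense + lexical blend.

    Returns a dict with the formatted evidence `context`, per-record `hits`,
    the best combined score as `confidence`, and the abstain flag/reason.
    """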
if rag_index.model is None or rag_index.embeddings is None or not rag_index.records:
return {
"query": query,
"context": "",
"hits": [],
"confidence": 0.0,
"should_abstain": True,
"abstain_reason": "RAG index is unavailable.",
}
query_embedding = rag_index.model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
vector_scores = np.dot(rag_index.embeddings, query_embedding).astype(float)
query_tokens = _tokenize_for_search(query)
lexical_scores = np.array(
[_lexical_score(query_tokens, doc_tokens, rag_index.idf) for doc_tokens in rag_index.tokenized_documents],
dtype=float,
)
# Hybrid ranking: dense similarity for semantics + lexical overlap for exact KRCE entities.
scores = (0.78 * vector_scores) + (0.22 * lexical_scores)
if scores.size == 0:
return {
"query": query,
"context": "",
"hits": [],
"confidence": 0.0,
"should_abstain": True,
"abstain_reason": ABSTAIN_MESSAGE,
}
ranked_indices = scores.argsort()[::-1]
best_score = float(scores[ranked_indices[0]])
if best_score < MIN_CONFIDENCE:
return {
"query": query,
"context": "",
"hits": [],
"confidence": best_score,
"should_abstain": True,
"abstain_reason": ABSTAIN_MESSAGE,
}
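    # Keep at least five hits so list-style queries see enough evidence even
    # when the caller asks for a smaller top_k.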
selected_indices = ranked_indices[: max(top_k, 5)]
hits: list[dict[str, Any]] = []
blocks: list[str] = []
for rank, idx in enumerate(selected_indices, start=1):
score = float(scores[idx])
vector_score = float(vector_scores[idx])
lexical_score = float(lexical_scores[idx])
record = rag_index.records[int(idx)]
hits.append(
{
"rank": rank,
"instruction": record["instruction"],
"output": record["output"],
"combined_score": score,
"vector_score": vector_score,
"lexical_score": lexical_score,
"specific_overlap": 0.0,
"role_overlap": 0.0,
}
)
blocks.append(
f"[KB-{rank} | score={score:.3f}]\n"
f"Question: {record['instruction']}\n"
f"Answer: {record['output']}"
)
return {
"query": query,
"context": "\n\n".join(blocks),
"hits": hits,
"confidence": best_score,
"should_abstain": False,
"abstain_reason": "",
}
def build_system_prompt(now: str, query: str, rag_result: dict[str, Any] | None) -> str:
prompt = (
f"You are Krish Mind, a grounded assistant for KRCE.\n"
f"CURRENT TIME: {now}\n\n"
"RULES:\n"
"- For KRCE facts, answer only from the KRCE evidence block.\n"
"- Synthesize the final answer in your own words; do not copy long raw blocks.\n"
"- Remove duplicates and repeated names.\n"
"- For list-style queries, return a clean bullet list.\n"
"- If the evidence does not directly answer, reply exactly: I don't know from the KRCE knowledge base.\n"
"- Do not invent people, roles, creator/founder claims, or hidden details.\n"
"- Keep the answer short and factual.\n"
)
if rag_result and rag_result.get("context"):
prompt += (
f"\n[KRCE EVIDENCE]\n{rag_result['context']}\n[END KRCE EVIDENCE]\n"
"Use this evidence only."
)
else:
prompt += "\nNo KRCE evidence was retrieved."
return prompt
def build_general_system_prompt(now: str) -> str:
return (
f"You are Krish Mind, a helpful AI assistant.\n"
f"CURRENT TIME: {now}\n\n"
"RULES:\n"
"- Answer clearly and accurately using your own knowledge.\n"
"- Keep replies compact by default (typically 4-10 lines unless user asks for full detail).\n"
"- Use clean Markdown: short paragraphs, bullets for lists, fenced code blocks for code.\n"
"- Avoid very long single lines; wrap explanations into readable short lines.\n"
"- Do not mention creator/founder identity unless the user explicitly asks about it.\n"
"- Do not claim personal origin stories that are not asked by the user.\n"
"- Keep answers concise and structured.\n"
)
def build_hybrid_system_prompt(now: str, rag_result: dict[str, Any] | None) -> str:
prompt = (
f"You are Krish Mind, a helpful AI assistant for KRCE-related questions.\n"
f"CURRENT TIME: {now}\n\n"
"RULES:\n"
"- Use KRCE evidence when available for college-specific facts.\n"
"- For general explanation details not present in KRCE evidence, use your own knowledge.\n"
"- Do not invent creator/founder identity claims.\n"
)
if rag_result and rag_result.get("context"):
prompt += f"\n[KRCE EVIDENCE]\n{rag_result['context']}\n[END KRCE EVIDENCE]\n"
return prompt
def looks_like_hallucinated_identity_claim(text: str) -> bool:
normalized = normalize_text(text)
return any(marker in normalized for marker in HALLUCINATION_MARKERS)
def _contains_code_content(text: str) -> bool:
lowered = text.lower()
if "```" in text:
return True
code_markers = (
"def ",
"class ",
"#include",
"public static void main",
"void ",
"int main",
)
return any(marker in lowered for marker in code_markers)
def _remove_identity_lines(text: str) -> str:
    kept = [line for line in text.splitlines() if not looks_like_hallucinated_identity_claim(line)]
    return "\n".join(kept).strip()
def _is_generic_self_intro(text: str) -> bool:
normalized = normalize_text(text)
if not normalized:
return False
intro_prefixes = (
"i am krish mind",
"i m krish mind",
"hello i am krish mind",
"hi i am krish mind",
)
return any(normalized.startswith(prefix) for prefix in intro_prefixes)
def is_generic_self_intro(text: str) -> bool:
return _is_generic_self_intro(text)
def is_intro_or_identity_query(query: str) -> bool:
    normalized = normalize_text(query)
    # Short greetings must match whole tokens; a substring check would fire on
    # words such as "this" (contains "hi") or "they" (contains "hey").
    greeting_tokens = {"hi", "hello", "hey"}
    if greeting_tokens & set(normalized.split()):
        return True
    intro_markers = (
        "good morning",
        "good afternoon",
        "good evening",
        "who are you",
        "introduce yourself",
        "your name",
        "tell me about yourself",
    )
    return any(marker in normalized for marker in intro_markers)
def _extract_people_names(text: str) -> list[str]:
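    """Extract deduplicated honorific-prefixed names from free text.

    Sketch with hypothetical record text:
        _extract_people_names("Faculty: Dr. Jane Doe, Mr. John Roe.")
        -> ["Dr. Jane Doe", "Mr. John Roe"]
    """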
found = NAME_PATTERN.findall(text)
cleaned: list[str] = []
seen = set()
for item in found:
name = re.sub(r"\s+", " ", item).strip(" ,.;")
name = re.sub(r"\s+(at|in)\s+krce\b", "", name, flags=re.IGNORECASE)
name = re.sub(r"\s+in\s+(cse|ece|eee|it|csbs|aids)\b", "", name, flags=re.IGNORECASE)
name = re.sub(r"\.(\s*(professors?|labs?|department).*)$", "", name, flags=re.IGNORECASE)
name = name.strip(" ,.;")
key = normalize_text(name)
        if len(name) < 6:
            continue
        # "tell me" also covers "tell me about", so one filter pass suffices.
        if any(bad in key for bad in ("professor", "lab", "department", "krce", "tell me", "who are")):
            continue
if key in seen:
continue
seen.add(key)
cleaned.append(name)
return cleaned
def build_deterministic_krce_answer(query: str, rag_result: dict[str, Any]) -> str:
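    """Answer list/location/direct-fact queries straight from retrieved hits.

    Returns "" when no deterministic answer applies, so the caller can fall
    back to the generative path.
    """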
normalized_query = normalize_text(query)
location_intent = ("where" in normalized_query and "department" in normalized_query)
list_intent = any(marker in normalized_query for marker in ("staff", "staffs", "faculty", "members", "list"))
factual_direct_intent = any(
token in normalized_query
for token in (
"who is",
"principal",
"chairman",
"vice principal",
"controller of examinations",
"deputy controller",
"hod",
"coordinator",
"contact",
"email",
"working hours",
"bus",
"attendance",
"mobile phone",
"dress code",
)
)
if not list_intent and not location_intent and not factual_direct_intent:
return ""
hits = rag_result.get("hits") or []
if not hits:
return ""
department_key = ""
for dep in ("cse", "ece", "eee", "it", "csbs", "ai ds", "aids"):
if re.search(rf"\b{re.escape(dep)}\b", normalized_query):
department_key = dep
break
filtered_hits = hits
if department_key:
scoped_hits = []
for hit in hits:
merged = f"{hit.get('instruction', '')} {hit.get('output', '')}"
if re.search(rf"\b{re.escape(department_key)}\b", normalize_text(merged)):
scoped_hits.append(hit)
if scoped_hits:
filtered_hits = scoped_hits
if factual_direct_intent and not list_intent and not location_intent:
if filtered_hits:
first = str(filtered_hits[0].get("output", "")).strip()
if first:
return first
if location_intent:
floor_pattern = re.compile(r"\b(ground|first|second|third|fourth|fifth)\s+floor\b", re.IGNORECASE)
for hit in filtered_hits:
output = str(hit.get("output", ""))
floor_match = floor_pattern.search(output)
if floor_match:
sentence = output.strip().split(".")[0].strip()
if sentence:
return sentence + "."
all_names: list[str] = []
seen = set()
for hit in filtered_hits:
output = str(hit.get("output", ""))
for name in _extract_people_names(output):
key = normalize_text(name)
if key in seen:
continue
seen.add(key)
all_names.append(name)
if not all_names:
return ""
if re.search(r"\b(male|boys|boy)\b", normalized_query):
filtered = [name for name in all_names if name.startswith(("Mr.",))]
if filtered:
all_names = filtered
elif re.search(r"\b(female|girls|girl)\b", normalized_query):
filtered = [name for name in all_names if name.startswith(("Mrs.", "Ms."))]
if filtered:
all_names = filtered
department = ""
for dep in ("cse", "ece", "eee", "it", "csbs", "ai ds", "aids"):
if dep in normalized_query:
department = dep.upper()
break
heading = f"{department} staff list:" if department else "Staff list:"
bullet_lines = "\n".join(f"- {name}" for name in all_names[:60])
return f"{heading}\n{bullet_lines}"
def compose_krce_response(query: str, rag_result: dict[str, Any]) -> str:
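    """Return the top hit directly, or a deduplicated bullet list for list-style queries."""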
hits = rag_result.get("hits") or []
if not hits:
return ABSTAIN_MESSAGE
normalized_query = normalize_text(query)
is_list_query = any(marker in normalized_query for marker in LIST_QUERY_MARKERS)
if not is_list_query:
return str(hits[0].get("output", "")).strip() or ABSTAIN_MESSAGE
unique_outputs: list[str] = []
seen = set()
for hit in hits:
output = str(hit.get("output", "")).strip()
if not output:
continue
key = normalize_text(output)
if key in seen:
continue
seen.add(key)
unique_outputs.append(output)
if not unique_outputs:
return ABSTAIN_MESSAGE
if len(unique_outputs) == 1:
return unique_outputs[0]
return "\n".join(f"- {line}" for line in unique_outputs)
def finalize_krce_response(query: str, response_text: str, rag_result: dict[str, Any] | None) -> str:
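    """Force the abstain message for KRCE-scope queries with weak or hallucinated answers."""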
if not response_text:
return ABSTAIN_MESSAGE if is_krce_scope_query(query) else response_text
if is_krce_scope_query(query):
if looks_like_hallucinated_identity_claim(response_text):
return ABSTAIN_MESSAGE
if rag_result and rag_result.get("should_abstain"):
return ABSTAIN_MESSAGE
return response_text
def finalize_general_response(query: str, response_text: str) -> str:
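    """Trim hallucinated identity claims from general answers, sparing identity/intro queries and code."""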
if not response_text:
return response_text
normalized_query = normalize_text(query)
identity_query = any(token in normalized_query for token in ("who created", "creator", "founder", "who are you"))
intro_query = is_intro_or_identity_query(query)
if identity_query:
return response_text
if intro_query:
return response_text
# For code answers, do not aggressively trim the full response.
if _contains_code_content(response_text):
cleaned_code_answer = _remove_identity_lines(response_text)
return cleaned_code_answer or response_text
    if looks_like_hallucinated_identity_claim(response_text):
        # Search the raw lowercased text so marker offsets map straight back
        # into response_text; normalize_text() collapses punctuation and
        # whitespace, which would shift the cut position.
        lowered = response_text.lower()
        cut_positions = [pos for pos in (lowered.find(marker) for marker in HALLUCINATION_MARKERS) if pos != -1]
        if cut_positions:
            cleaned = response_text[: min(cut_positions)].rstrip(" ,.;")
        else:
            # The marker only surfaced after normalization; drop whole lines instead.
            cleaned = _remove_identity_lines(response_text)
        if cleaned:
            return cleaned
        return "I can help with this topic. Please ask the question directly and I will answer clearly."
    return response_text
def needs_general_retry(query: str, response_text: str) -> bool:
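    """Signal a retry when a general answer is empty, a canned self-intro, or an identity hallucination."""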
if not response_text:
return True
normalized_query = normalize_text(query)
identity_query = any(token in normalized_query for token in ("who created", "creator", "founder", "who are you"))
if identity_query:
return False
if is_intro_or_identity_query(query):
return False
if _is_generic_self_intro(response_text):
return True
# Avoid forcing retries for long-form coding answers; retries can degrade code quality.
if _contains_code_content(response_text):
return False
return looks_like_hallucinated_identity_claim(response_text)
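

if __name__ == "__main__":
    # Minimal smoke test (sketch): needs the JSONL data file on disk and, on
    # the first run, network access to download the embedding model.
    _index = load_rag_index()
    _result = search_krce("Who is the principal of KRCE?", _index)
    if _result["should_abstain"]:
        print(_result["abstain_reason"])
    else:
        print(compose_krce_response("Who is the principal of KRCE?", _result))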