"""Retrieval-augmented answering utilities for the KRCE college knowledge base."""
| from __future__ import annotations | |
| import json | |
| import math | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from typing import Any | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
# Path to the JSONL knowledge base shipped alongside this module.
DEFAULT_DATA_FILE = os.path.join(os.path.dirname(__file__), "data", "krce_college_data.jsonl")
# sentence-transformers model used for dense retrieval embeddings.
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Canonical refusal text returned whenever evidence is missing or too weak.
ABSTAIN_MESSAGE = "I don't know from the KRCE knowledge base."
# Keep this simple: only a minimal relevance threshold.
MIN_CONFIDENCE = 0.25
# Default number of knowledge-base hits returned by search_krce.
TOP_K = 3
# Words ignored when tokenizing queries and documents for lexical matching.
SEARCH_STOPWORDS = {
    "a", "an", "and", "are", "at", "be", "for", "from", "how", "in", "is", "it", "of", "on", "or",
    "the", "to", "what", "when", "where", "who", "with", "your", "please", "tell", "me", "about",
}
# Lightweight post-generation safety net.
# Phrases (in normalized form) that indicate a fabricated creator/identity claim.
HALLUCINATION_MARKERS = (
    "created by",
    "created independently",
    "created after leaving",
    "des created me",
    "i was created",
    "krish cs my creator",
    "my creator",
    "my founder",
)
# Substrings that suggest a general CS/programming question rather than a KRCE fact.
GENERAL_KNOWLEDGE_MARKERS = (
    "algorithm",
    "array",
    "binary tree",
    "coding",
    "computer science",
    "data structure",
    "debug",
    "explain",
    "merge sort",
    "python",
    "quick sort",
    "sorting",
    "stack",
)
# Substrings that suggest the user wants a multi-item (list-style) answer.
LIST_QUERY_MARKERS = (
    "all",
    "boys",
    "faculty",
    "faculties",
    "girls",
    "list",
    "members",
    "restroom",
    "restrooms",
    "staff",
    "staffs",
    "washroom",
    "washrooms",
    "who are",
)
# Query fragments that sometimes leak into stored answers; everything from the
# first occurrence onward is trimmed by _clean_output_text.
TRAILING_QUERY_NOISE_MARKERS = (
    ", tell me about ",
    ", who are ",
    ", who is ",
    ", how many ",
    ", i m a cse student",
    ", i am a cse student",
    ", is dr ",
    ", krce cse",
    ", my hod if",
    ", my hod if",
)
# Matches honorific-prefixed person names, e.g. "Dr. A. B. Kumar".
NAME_PATTERN = re.compile(r"\b(?:Dr|Mr|Mrs|Ms)\.\s*[A-Za-z][A-Za-z\s.]{1,70}")
@dataclass
class RagIndex:
    """Container for everything needed to answer KRCE queries.

    Bug fix: this class is constructed with keyword arguments throughout
    the module (e.g. ``RagIndex(model=None, records=[], ...)``) but had no
    generated ``__init__``; the ``@dataclass`` decorator (already imported
    at the top of the file) restores the intended constructor.
    """

    # Sentence-embedding model, or None when loading failed / no data file.
    model: SentenceTransformer | None
    # Raw KB records: {"instruction": ..., "output": ...}.
    records: list[dict[str, str]]
    # "instruction\noutput" text per record, used for dense embedding.
    documents: list[str]
    # L2-normalized document embeddings, row-aligned with ``documents``.
    embeddings: np.ndarray | None
    # Search tokens per document, aligned with ``documents``.
    tokenized_documents: list[list[str]]
    # Token -> smoothed IDF weight for lexical scoring.
    idf: dict[str, float]
def normalize_text(text: str) -> str:
    """Lowercase *text* and reduce it to space-separated alphanumeric tokens.

    Apostrophes, slashes, hyphens and dots become spaces, every other
    non-alphanumeric character is stripped, and whitespace runs collapse
    into a single space.
    """
    lowered = text.lower()
    for separator in ("'", "/", "-"):
        lowered = lowered.replace(separator, " ")
    lowered = re.sub(r"[^a-z0-9\s.]+", " ", lowered)
    lowered = lowered.replace(".", " ")
    collapsed = re.sub(r"\s+", " ", lowered)
    return collapsed.strip()
def _tokenize_for_search(text: str) -> list[str]:
    """Split *text* into normalized tokens with stopwords removed."""
    words = normalize_text(text).split()
    return [word for word in words if word not in SEARCH_STOPWORDS]
| def _build_idf(tokenized_documents: list[list[str]]) -> dict[str, float]: | |
| if not tokenized_documents: | |
| return {} | |
| doc_freq: dict[str, int] = {} | |
| total_docs = len(tokenized_documents) | |
| for tokens in tokenized_documents: | |
| unique_tokens = set(tokens) | |
| for token in unique_tokens: | |
| doc_freq[token] = doc_freq.get(token, 0) + 1 | |
| idf: dict[str, float] = {} | |
| for token, freq in doc_freq.items(): | |
| idf[token] = math.log((total_docs + 1.0) / (freq + 1.0)) + 1.0 | |
| return idf | |
| def _lexical_score(query_tokens: list[str], doc_tokens: list[str], idf: dict[str, float]) -> float: | |
| if not query_tokens or not doc_tokens: | |
| return 0.0 | |
| doc_set = set(doc_tokens) | |
| weighted_overlap = sum(idf.get(token, 1.0) for token in query_tokens if token in doc_set) | |
| weighted_total = sum(idf.get(token, 1.0) for token in query_tokens) | |
| if weighted_total <= 0: | |
| return 0.0 | |
| return weighted_overlap / weighted_total | |
def _clean_output_text(output: str) -> str:
    """Trim trailing query-noise fragments that leaked into a stored answer.

    The answer is cut at the earliest marker occurrence (case-insensitive)
    and stripped of trailing separators.
    """
    cleaned = output.strip()
    lowered = cleaned.lower()
    positions = [
        lowered.find(marker)
        for marker in TRAILING_QUERY_NOISE_MARKERS
        if marker in lowered
    ]
    if positions:
        cleaned = cleaned[: min(positions)].rstrip(" ,;")
    return cleaned
def is_krce_scope_query(query: str) -> bool:
    """Heuristically decide whether *query* is about KRCE itself."""
    # Minimal scope check to decide when to force abstain on low confidence.
    krce_terms = (
        "krce",
        "k ramakrishnan",
        "college",
        "department",
        "faculty",
        "hod",
        "principal",
        "professor",
        "cse",
        "ece",
        "eee",
        "ai ds",
        "aids",
        "csbs",
    )
    normalized = normalize_text(query)
    for term in krce_terms:
        if term in normalized:
            return True
    return False
def classify_query_route(query: str) -> str:
    """Route *query* to "krce", "general", or "hybrid" handling."""
    normalized = normalize_text(query)
    in_krce_scope = is_krce_scope_query(query)
    in_general_scope = any(marker in normalized for marker in GENERAL_KNOWLEDGE_MARKERS)
    if in_krce_scope:
        # A KRCE question that also mentions CS topics gets both contexts.
        return "hybrid" if in_general_scope else "krce"
    return "general"
def _load_records(data_file: str) -> list[dict[str, str]]:
    """Parse the JSONL knowledge base, skipping blank or malformed lines.

    Each record carries a cleaned ``instruction`` / ``output`` pair;
    records where both fields are empty are dropped.
    """
    records: list[dict[str, str]] = []
    with open(data_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                item = json.loads(raw_line)
            except json.JSONDecodeError:
                # Best-effort loader: ignore corrupt lines rather than fail.
                continue
            instruction = str(item.get("instruction", "")).strip()
            output = _clean_output_text(str(item.get("output", "")))
            if instruction or output:
                records.append({"instruction": instruction, "output": output})
    return records
def load_rag_index(data_file: str = DEFAULT_DATA_FILE, embedding_model: str = DEFAULT_EMBEDDING_MODEL) -> RagIndex:
    """Build the full RAG index: model, records, embeddings, lexical stats.

    Returns an empty index (model=None) when the data file is missing or
    the embedding model cannot be loaded.
    """
    empty_kwargs = dict(model=None, records=[], documents=[], embeddings=None, tokenized_documents=[], idf={})
    if not os.path.exists(data_file):
        return RagIndex(**empty_kwargs)
    try:
        model = SentenceTransformer(embedding_model)
    except Exception:
        # Model download/load can fail offline; degrade to an empty index.
        return RagIndex(**empty_kwargs)
    records = _load_records(data_file)
    documents = [f"{record['instruction']}\n{record['output']}".strip() for record in records]
    if documents:
        embeddings = model.encode(documents, normalize_embeddings=True, convert_to_numpy=True)
    else:
        embeddings = np.empty((0, 0), dtype=np.float32)
    tokenized_documents = [_tokenize_for_search(doc) for doc in documents]
    return RagIndex(
        model=model,
        records=records,
        documents=documents,
        embeddings=embeddings,
        tokenized_documents=tokenized_documents,
        idf=_build_idf(tokenized_documents),
    )
def search_krce(query: str, rag_index: RagIndex, top_k: int = TOP_K) -> dict[str, Any]:
    """Retrieve the top knowledge-base hits for *query* via hybrid ranking.

    Returns a dict with keys ``query``, ``context`` (formatted evidence
    block), ``hits`` (ranked records with score breakdowns), ``confidence``
    (best combined score), ``should_abstain`` and ``abstain_reason``.
    """
    # Degenerate index (no model, no embeddings, or no records): abstain.
    if rag_index.model is None or rag_index.embeddings is None or not rag_index.records:
        return {
            "query": query,
            "context": "",
            "hits": [],
            "confidence": 0.0,
            "should_abstain": True,
            "abstain_reason": "RAG index is unavailable.",
        }
    # Embeddings are L2-normalized on both sides, so the dot product below
    # is cosine similarity.
    query_embedding = rag_index.model.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
    vector_scores = np.dot(rag_index.embeddings, query_embedding).astype(float)
    query_tokens = _tokenize_for_search(query)
    lexical_scores = np.array(
        [_lexical_score(query_tokens, doc_tokens, rag_index.idf) for doc_tokens in rag_index.tokenized_documents],
        dtype=float,
    )
    # Hybrid ranking: dense similarity for semantics + lexical overlap for exact KRCE entities.
    scores = (0.78 * vector_scores) + (0.22 * lexical_scores)
    if scores.size == 0:
        return {
            "query": query,
            "context": "",
            "hits": [],
            "confidence": 0.0,
            "should_abstain": True,
            "abstain_reason": ABSTAIN_MESSAGE,
        }
    # Indices of documents sorted by descending combined score.
    ranked_indices = scores.argsort()[::-1]
    best_score = float(scores[ranked_indices[0]])
    if best_score < MIN_CONFIDENCE:
        return {
            "query": query,
            "context": "",
            "hits": [],
            "confidence": best_score,
            "should_abstain": True,
            "abstain_reason": ABSTAIN_MESSAGE,
        }
    # Always keep at least 5 hits so downstream list answers have material.
    selected_indices = ranked_indices[: max(top_k, 5)]
    hits: list[dict[str, Any]] = []
    blocks: list[str] = []
    for rank, idx in enumerate(selected_indices, start=1):
        score = float(scores[idx])
        vector_score = float(vector_scores[idx])
        lexical_score = float(lexical_scores[idx])
        record = rag_index.records[int(idx)]
        hits.append(
            {
                "rank": rank,
                "instruction": record["instruction"],
                "output": record["output"],
                "combined_score": score,
                "vector_score": vector_score,
                "lexical_score": lexical_score,
                # Placeholder fields kept for downstream schema compatibility.
                "specific_overlap": 0.0,
                "role_overlap": 0.0,
            }
        )
        blocks.append(
            f"[KB-{rank} | score={score:.3f}]\n"
            f"Question: {record['instruction']}\n"
            f"Answer: {record['output']}"
        )
    return {
        "query": query,
        "context": "\n\n".join(blocks),
        "hits": hits,
        "confidence": best_score,
        "should_abstain": False,
        "abstain_reason": "",
    }
def build_system_prompt(now: str, query: str, rag_result: dict[str, Any] | None) -> str:
    """Compose the grounded KRCE system prompt, embedding retrieved evidence."""
    header = (
        f"You are Krish Mind, a grounded assistant for KRCE.\n"
        f"CURRENT TIME: {now}\n\n"
    )
    rules = (
        "RULES:\n"
        "- For KRCE facts, answer only from the KRCE evidence block.\n"
        "- Synthesize the final answer in your own words; do not copy long raw blocks.\n"
        "- Remove duplicates and repeated names.\n"
        "- For list-style queries, return a clean bullet list.\n"
        "- If the evidence does not directly answer, reply exactly: I don't know from the KRCE knowledge base.\n"
        "- Do not invent people, roles, creator/founder claims, or hidden details.\n"
        "- Keep the answer short and factual.\n"
    )
    context = rag_result.get("context") if rag_result else None
    if context:
        evidence = (
            f"\n[KRCE EVIDENCE]\n{context}\n[END KRCE EVIDENCE]\n"
            "Use this evidence only."
        )
    else:
        evidence = "\nNo KRCE evidence was retrieved."
    return header + rules + evidence
def build_general_system_prompt(now: str) -> str:
    """Build the system prompt used for general (non-KRCE) questions."""
    rule_lines = [
        "RULES:",
        "- Answer clearly and accurately using your own knowledge.",
        "- Keep replies compact by default (typically 4-10 lines unless user asks for full detail).",
        "- Use clean Markdown: short paragraphs, bullets for lists, fenced code blocks for code.",
        "- Avoid very long single lines; wrap explanations into readable short lines.",
        "- Do not mention creator/founder identity unless the user explicitly asks about it.",
        "- Do not claim personal origin stories that are not asked by the user.",
        "- Keep answers concise and structured.",
    ]
    header = f"You are Krish Mind, a helpful AI assistant.\nCURRENT TIME: {now}\n\n"
    return header + "\n".join(rule_lines) + "\n"
def build_hybrid_system_prompt(now: str, rag_result: dict[str, Any] | None) -> str:
    """Compose the hybrid prompt: KRCE evidence plus general knowledge."""
    parts = [
        f"You are Krish Mind, a helpful AI assistant for KRCE-related questions.\n"
        f"CURRENT TIME: {now}\n\n"
        "RULES:\n"
        "- Use KRCE evidence when available for college-specific facts.\n"
        "- For general explanation details not present in KRCE evidence, use your own knowledge.\n"
        "- Do not invent creator/founder identity claims.\n"
    ]
    context = rag_result.get("context") if rag_result else None
    if context:
        parts.append(f"\n[KRCE EVIDENCE]\n{context}\n[END KRCE EVIDENCE]\n")
    return "".join(parts)
def looks_like_hallucinated_identity_claim(text: str) -> bool:
    """True when *text* contains a known fabricated creator/identity phrase."""
    normalized = normalize_text(text)
    for marker in HALLUCINATION_MARKERS:
        if marker in normalized:
            return True
    return False
| def _contains_code_content(text: str) -> bool: | |
| lowered = text.lower() | |
| if "```" in text: | |
| return True | |
| code_markers = ( | |
| "def ", | |
| "class ", | |
| "#include", | |
| "public static void main", | |
| "void ", | |
| "int main", | |
| ) | |
| return any(marker in lowered for marker in code_markers) | |
def _remove_identity_lines(text: str) -> str:
    """Drop every line that trips the hallucination-marker check."""
    kept = [
        line
        for line in text.splitlines()
        if not looks_like_hallucinated_identity_claim(line)
    ]
    return "\n".join(kept).strip()
def _is_generic_self_intro(text: str) -> bool:
    """True when *text* starts as the model's generic self-introduction."""
    normalized = normalize_text(text)
    if not normalized:
        return False
    intro_prefixes = (
        "i am krish mind",
        "i m krish mind",
        "hello i am krish mind",
        "hi i am krish mind",
    )
    for prefix in intro_prefixes:
        if normalized.startswith(prefix):
            return True
    return False
def is_generic_self_intro(text: str) -> bool:
    """Public alias for the private self-intro check, for use by callers outside this module."""
    return _is_generic_self_intro(text)
def is_intro_or_identity_query(query: str) -> bool:
    """Return True for greetings or questions about the assistant itself.

    Bug fix: short greeting words were previously matched as substrings of
    the normalized query, so any query containing "hi" inside a word
    (e.g. "machine", "this") or "hey" (e.g. "they") was misclassified as a
    greeting. Single-word greetings are now matched as whole tokens;
    multi-word phrases keep substring matching.
    """
    normalized = normalize_text(query)
    tokens = set(normalized.split())
    # Whole-token match for one-word greetings (avoids "this" -> "hi").
    greeting_words = ("hi", "hello", "hey")
    if any(word in tokens for word in greeting_words):
        return True
    phrase_markers = (
        "good morning",
        "good afternoon",
        "good evening",
        "who are you",
        "introduce yourself",
        "your name",
        "tell me about yourself",
    )
    return any(marker in normalized for marker in phrase_markers)
def _extract_people_names(text: str) -> list[str]:
    """Extract unique honorific-prefixed names (Dr/Mr/Mrs/Ms) from *text*.

    Matches are cleaned of trailing department/location fragments and
    deduplicated by normalized form, preserving first-seen order.
    """
    found = NAME_PATTERN.findall(text)
    cleaned: list[str] = []
    seen = set()
    for item in found:
        # Collapse whitespace and strip stray punctuation around the match.
        name = re.sub(r"\s+", " ", item).strip(" ,.;")
        # The greedy pattern can swallow trailing context; peel it off.
        name = re.sub(r"\s+(at|in)\s+krce\b", "", name, flags=re.IGNORECASE)
        name = re.sub(r"\s+in\s+(cse|ece|eee|it|csbs|aids)\b", "", name, flags=re.IGNORECASE)
        name = re.sub(r"\.(\s*(professors?|labs?|department).*)$", "", name, flags=re.IGNORECASE)
        name = name.strip(" ,.;")
        key = normalize_text(name)
        # Too short to be a real "Title. Name" match.
        if len(name) < 6:
            continue
        # Residual non-name words survived cleanup; discard the candidate.
        if any(bad in key for bad in ("professor", "lab", "department", "krce", "tell me", "who are")):
            continue
        # NOTE(review): redundant with the check above ("tell me" / "who are"
        # are already filtered) -- kept for byte-identical behavior.
        if "tell me about" in key or "who are" in key:
            continue
        if key in seen:
            continue
        seen.add(key)
        cleaned.append(name)
    return cleaned
def build_deterministic_krce_answer(query: str, rag_result: dict[str, Any]) -> str:
    """Try to answer *query* deterministically from retrieved hits.

    Handles three intents: location questions ("where ... department"),
    direct factual lookups (principal, HOD, contact, ...), and staff/list
    queries. Returns "" when no deterministic answer applies, so the
    caller can fall back to generation.
    """
    normalized_query = normalize_text(query)
    location_intent = ("where" in normalized_query and "department" in normalized_query)
    list_intent = any(marker in normalized_query for marker in ("staff", "staffs", "faculty", "members", "list"))
    factual_direct_intent = any(
        token in normalized_query
        for token in (
            "who is",
            "principal",
            "chairman",
            "vice principal",
            "controller of examinations",
            "deputy controller",
            "hod",
            "coordinator",
            "contact",
            "email",
            "working hours",
            "bus",
            "attendance",
            "mobile phone",
            "dress code",
        )
    )
    # No recognized intent: defer to the generative path.
    if not list_intent and not location_intent and not factual_direct_intent:
        return ""
    hits = rag_result.get("hits") or []
    if not hits:
        return ""
    # Detect an explicit department mention to scope the hits.
    department_key = ""
    for dep in ("cse", "ece", "eee", "it", "csbs", "ai ds", "aids"):
        if re.search(rf"\b{re.escape(dep)}\b", normalized_query):
            department_key = dep
            break
    filtered_hits = hits
    if department_key:
        # Keep only hits mentioning the department; fall back to all hits
        # if none match.
        scoped_hits = []
        for hit in hits:
            merged = f"{hit.get('instruction', '')} {hit.get('output', '')}"
            if re.search(rf"\b{re.escape(department_key)}\b", normalize_text(merged)):
                scoped_hits.append(hit)
        if scoped_hits:
            filtered_hits = scoped_hits
    # Pure factual lookup: return the top hit's answer verbatim.
    if factual_direct_intent and not list_intent and not location_intent:
        if filtered_hits:
            first = str(filtered_hits[0].get("output", "")).strip()
            if first:
                return first
    # Location question: return the first sentence that names a floor.
    if location_intent:
        floor_pattern = re.compile(r"\b(ground|first|second|third|fourth|fifth)\s+floor\b", re.IGNORECASE)
        for hit in filtered_hits:
            output = str(hit.get("output", ""))
            floor_match = floor_pattern.search(output)
            if floor_match:
                sentence = output.strip().split(".")[0].strip()
                if sentence:
                    return sentence + "."
    # List/staff query: collect unique person names across the hits.
    all_names: list[str] = []
    seen = set()
    for hit in filtered_hits:
        output = str(hit.get("output", ""))
        for name in _extract_people_names(output):
            key = normalize_text(name)
            if key in seen:
                continue
            seen.add(key)
            all_names.append(name)
    if not all_names:
        return ""
    # Optional gender filter based on the honorific prefix; keep the full
    # list if the filter would leave nothing.
    if re.search(r"\b(male|boys|boy)\b", normalized_query):
        filtered = [name for name in all_names if name.startswith(("Mr.",))]
        if filtered:
            all_names = filtered
    elif re.search(r"\b(female|girls|girl)\b", normalized_query):
        filtered = [name for name in all_names if name.startswith(("Mrs.", "Ms."))]
        if filtered:
            all_names = filtered
    # Heading uses the department acronym when one was mentioned.
    department = ""
    for dep in ("cse", "ece", "eee", "it", "csbs", "ai ds", "aids"):
        if dep in normalized_query:
            department = dep.upper()
            break
    heading = f"{department} staff list:" if department else "Staff list:"
    # Cap the bullet list at 60 names to bound the response size.
    bullet_lines = "\n".join(f"- {name}" for name in all_names[:60])
    return f"{heading}\n{bullet_lines}"
def compose_krce_response(query: str, rag_result: dict[str, Any]) -> str:
    """Build a grounded answer directly from the retrieved hits.

    Non-list queries get the top hit's answer; list queries get a
    deduplicated bullet list of all hit outputs.
    """
    hits = rag_result.get("hits") or []
    if not hits:
        return ABSTAIN_MESSAGE
    normalized_query = normalize_text(query)
    wants_list = any(marker in normalized_query for marker in LIST_QUERY_MARKERS)
    if not wants_list:
        top_output = str(hits[0].get("output", "")).strip()
        return top_output or ABSTAIN_MESSAGE
    seen: set[str] = set()
    unique_outputs: list[str] = []
    for hit in hits:
        output = str(hit.get("output", "")).strip()
        if not output:
            continue
        key = normalize_text(output)
        if key not in seen:
            seen.add(key)
            unique_outputs.append(output)
    if not unique_outputs:
        return ABSTAIN_MESSAGE
    if len(unique_outputs) == 1:
        return unique_outputs[0]
    return "\n".join(f"- {line}" for line in unique_outputs)
def finalize_krce_response(query: str, response_text: str, rag_result: dict[str, Any] | None) -> str:
    """Apply abstain and hallucination guards to a KRCE-scoped answer."""
    in_scope = is_krce_scope_query(query)
    if not response_text:
        # Empty answer to a KRCE question becomes an explicit abstain.
        return ABSTAIN_MESSAGE if in_scope else response_text
    if in_scope:
        if looks_like_hallucinated_identity_claim(response_text):
            return ABSTAIN_MESSAGE
        if rag_result and rag_result.get("should_abstain"):
            return ABSTAIN_MESSAGE
    return response_text
def finalize_general_response(query: str, response_text: str) -> str:
    """Post-process a general-knowledge answer, trimming fabricated identity claims.

    Identity/intro queries are returned untouched; code answers only get
    offending lines removed; other answers are cut at the first
    hallucination marker.
    """
    if not response_text:
        return response_text
    normalized_query = normalize_text(query)
    identity_query = any(token in normalized_query for token in ("who created", "creator", "founder", "who are you"))
    intro_query = is_intro_or_identity_query(query)
    if identity_query or intro_query:
        # The user explicitly asked; let the model's answer stand.
        return response_text
    # For code answers, do not aggressively trim the full response.
    if _contains_code_content(response_text):
        cleaned_code_answer = _remove_identity_lines(response_text)
        return cleaned_code_answer or response_text
    if looks_like_hallucinated_identity_claim(response_text):
        # Bug fix: marker offsets must be located in the raw lowercased
        # response, not in normalize_text(response_text) -- normalization
        # collapses whitespace/punctuation, so its indices do not line up
        # with response_text and slicing at them cut answers at the wrong
        # place. (This mirrors how _clean_output_text searches.)
        lowered = response_text.lower()
        cut_positions = [pos for pos in (lowered.find(marker) for marker in HALLUCINATION_MARKERS) if pos != -1]
        if cut_positions:
            cleaned = response_text[: min(cut_positions)].rstrip(" ,.;")
            if cleaned:
                return cleaned
            return "I can help with this topic. Please ask the question directly and I will answer clearly."
        # Marker only detectable after normalization; cannot map an offset
        # back safely, so keep the response as-is (matches prior behavior).
        return response_text
    return response_text
def needs_general_retry(query: str, response_text: str) -> bool:
    """Decide whether a general-knowledge answer should be regenerated."""
    if not response_text:
        return True
    normalized_query = normalize_text(query)
    asked_about_identity = any(
        token in normalized_query for token in ("who created", "creator", "founder", "who are you")
    )
    if asked_about_identity:
        return False
    if is_intro_or_identity_query(query):
        return False
    if _is_generic_self_intro(response_text):
        return True
    # Avoid forcing retries for long-form coding answers; retries can degrade code quality.
    if _contains_code_content(response_text):
        return False
    return looks_like_hallucinated_identity_claim(response_text)