"""UTN Student Chatbot — Gradio app with CRAG pipeline.""" import logging import re import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer from prompt import REWRITE_PROMPT, build_chat_messages from retriever import Retriever logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged" logger.info("Initializing retriever...") retriever = Retriever( faiss_index_path="faiss.index", chunks_meta_path="chunks_meta.jsonl", embedding_model="BAAI/bge-small-en-v1.5", top_k=5, ) logger.info("Loading model: %s", MODEL_ID) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=dtype, trust_remote_code=True, ).to(device) model.eval() logger.info("Model loaded.") def _generate(messages: list[dict], max_tokens: int = 512) -> str: prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048) inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=max_tokens, temperature=0.3, top_p=0.9, do_sample=True, pad_token_id=tokenizer.pad_token_id, ) return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip() def _grade_relevance(question: str, sources: list[dict]) -> bool: if not sources: return False top_score = sources[0].get("score", 0) q_tokens = set(re.findall(r"\w+", question.lower())) doc_tokens = set(re.findall(r"\w+", sources[0].get("text", "").lower())) stopwords = { "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of", "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where", "who", "which", "this", "that", "be", "are", "was", "have", "has", } q_content = q_tokens - stopwords overlap = len(q_content & doc_tokens) / max(len(q_content), 1) return top_score >= 0.02 or overlap >= 0.35 def crag_answer(message: str, history: list[dict]) -> str: question = message.strip() if not question: return "Please ask a question about UTN." sources = retriever.retrieve(question) relevant = _grade_relevance(question, sources) if not relevant: rewrite_msgs = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": REWRITE_PROMPT.format(question=question)}, ] rewritten = _generate(rewrite_msgs, max_tokens=100) rewritten = rewritten.split("\n")[0].strip() if rewritten and rewritten != question: sources = retriever.retrieve(rewritten) context = retriever.format_context(sources) messages = build_chat_messages(question, context) answer = _generate(messages) return answer demo = gr.ChatInterface( fn=crag_answer, type="messages", title="UTN Student Chatbot", description="Ask questions about the University of Technology Nuremberg (UTN) — admissions, programs, courses, deadlines, and more. Powered by a finetuned Qwen3-0.6B with Corrective RAG.", examples=[ "What are the admission requirements for AI & Robotics?", "Are there tuition fees?", "What courses are in the first semester?", "Is there a Welcome Week?", "What TOEFL score do I need?", ], theme=gr.themes.Soft(), ) if __name__ == "__main__": demo.launch()