| """UTN Student Chatbot — Gradio app with CRAG pipeline.""" |
|
|
| import logging |
| import re |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from prompt import REWRITE_PROMPT, build_chat_messages |
| from retriever import Retriever |
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the fine-tuned chat model used for both query
# rewriting and answer generation.
MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged"

# --- Retrieval index ---------------------------------------------------------
# FAISS index + chunk metadata are read from the working directory.
# NOTE(review): top_k presumably caps the number of chunks per query — confirm
# against Retriever.
logger.info("Initializing retriever...")
retriever = Retriever(
    faiss_index_path="faiss.index",
    chunks_meta_path="chunks_meta.jsonl",
    embedding_model="BAAI/bge-small-en-v1.5",
    top_k=5,
)

# --- Generator model ---------------------------------------------------------
logger.info("Loading model: %s", MODEL_ID)
# NOTE(review): trust_remote_code lets the Hub repo run its own Python on load;
# likely unnecessary for a Qwen3 checkpoint on recent transformers — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # generate() is given tokenizer.pad_token_id, so make sure one exists.
    tokenizer.pad_token = tokenizer.eos_token

# bfloat16 on GPU for speed/memory; full float32 on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.to(device)
model.eval()  # inference only: disables dropout
logger.info("Model loaded.")
|
|
|
|
def _generate(
    messages: list[dict],
    max_tokens: int = 512,
    *,
    max_input_tokens: int = 2048,
) -> str:
    """Run one chat completion with the module-level model and return the reply.

    Args:
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts) fed
            to ``tokenizer.apply_chat_template``.
        max_tokens: Maximum number of newly generated tokens.
        max_input_tokens: Prompt token budget. A longer prompt keeps only its
            LAST ``max_input_tokens`` tokens, so the user question and the
            assistant generation header at the end of the template survive;
            the earliest tokens (system prompt / start of context) are dropped.

    Returns:
        The decoded continuation only (prompt excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # Qwen3 chat-template switch: answer directly without a thinking block.
        enable_thinking=False,
    )
    # The chat template already inserted all special tokens; don't add more.
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    prompt_len = encoded["input_ids"].shape[1]
    if prompt_len > max_input_tokens:
        # Truncating from the right (the previous behaviour) would cut off the
        # question and the assistant header, so the model would not answer.
        logging.getLogger(__name__).warning(
            "Prompt has %d tokens; keeping the last %d.", prompt_len, max_input_tokens
        )
        encoded = {k: v[:, -max_input_tokens:] for k, v in encoded.items()}
    inputs = {k: v.to(model.device) for k, v in encoded.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
# Word tokenizer and stopword list used by _grade_relevance; built once at
# import instead of on every call.
_WORD_RE = re.compile(r"\w+")
_RELEVANCE_STOPWORDS = frozenset({
    "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of",
    "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where",
    "who", "which", "this", "that", "be", "are", "was", "have", "has",
})


def _grade_relevance(
    question: str,
    sources: list[dict],
    *,
    min_score: float = 0.02,
    min_overlap: float = 0.35,
) -> bool:
    """Cheaply decide whether the top retrieved chunk is relevant to *question*.

    Only ``sources[0]`` is inspected (assumed to be the best hit — confirm the
    retriever returns sources sorted by descending score). It counts as
    relevant when either:

    * its ``"score"`` is at least ``min_score``, or
    * at least ``min_overlap`` of the question's non-stopword tokens
      (lower-cased ``\\w+`` words) also appear in its ``"text"``.

    A missing or ``None`` score is treated as 0, and missing or ``None`` text
    as empty.

    Args:
        question: The user's question.
        sources: Retrieved chunk dicts; ``"score"`` and ``"text"`` keys are read.
        min_score: Score threshold (scale depends on the retriever).
        min_overlap: Required fraction of question content words in the text.

    Returns:
        ``False`` for an empty *sources* list; otherwise the verdict above.
    """
    if not sources:
        return False
    top = sources[0]
    if (top.get("score") or 0.0) >= min_score:
        return True  # score alone suffices; skip tokenization
    q_content = set(_WORD_RE.findall(question.lower())) - _RELEVANCE_STOPWORDS
    doc_tokens = set(_WORD_RE.findall((top.get("text") or "").lower()))
    # max(..., 1) avoids division by zero for all-stopword questions.
    overlap = len(q_content & doc_tokens) / max(len(q_content), 1)
    return overlap >= min_overlap
|
|
|
|
def crag_answer(message: str, history: list[dict]) -> str:
    """Answer one chat turn with a Corrective-RAG loop.

    Steps: retrieve chunks for the question, grade the top hit with
    ``_grade_relevance``, and if it is judged irrelevant ask the model for a
    rewritten search query (first line of its reply) and retrieve again with
    it. The final sources are formatted into context and the answer is
    generated for the ORIGINAL question.

    Args:
        message: The user's latest message.
        history: Prior turns supplied by the chat UI. Currently unused, so
            every turn is answered independently.

    Returns:
        The generated answer, or a fixed prompt when *message* is blank.
    """
    question = message.strip()
    if not question:
        return "Please ask a question about UTN."

    sources = retriever.retrieve(question)

    if not _grade_relevance(question, sources):
        # Corrective step: let the model reformulate the query and re-search.
        rewrite_request = REWRITE_PROMPT.format(question=question)
        raw_rewrite = _generate(
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": rewrite_request},
            ],
            max_tokens=100,
        )
        candidate_query = raw_rewrite.partition("\n")[0].strip()
        # Re-retrieve only when the rewrite produced something new.
        if candidate_query and candidate_query != question:
            sources = retriever.retrieve(candidate_query)

    return _generate(
        build_chat_messages(question, retriever.format_context(sources))
    )
|
|
|
|
# Text shown under the title of the chat UI.
_DESCRIPTION = (
    "Ask questions about the University of Technology Nuremberg (UTN) — "
    "admissions, programs, courses, deadlines, and more. "
    "Powered by a finetuned Qwen3-0.6B with Corrective RAG."
)

# Clickable starter questions offered in the UI.
_EXAMPLES = [
    "What are the admission requirements for AI & Robotics?",
    "Are there tuition fees?",
    "What courses are in the first semester?",
    "Is there a Welcome Week?",
    "What TOEFL score do I need?",
]

# Chat UI wired to the CRAG pipeline; history is passed in "messages" format.
demo = gr.ChatInterface(
    fn=crag_answer,
    type="messages",
    title="UTN Student Chatbot",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    theme=gr.themes.Soft(),
)


if __name__ == "__main__":
    demo.launch()
|
|