File size: 3,965 Bytes
0cf8ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8db7c8
 
0cf8ad2
 
e8db7c8
0cf8ad2
e8db7c8
0cf8ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec87e2
0cf8ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""UTN Student Chatbot — Gradio app with CRAG pipeline."""

import logging
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from prompt import REWRITE_PROMPT, build_chat_messages
from retriever import Retriever

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged"

logger.info("Initializing retriever...")
retriever = Retriever(
    faiss_index_path="faiss.index",
    chunks_meta_path="chunks_meta.jsonl",
    embedding_model="BAAI/bge-small-en-v1.5",
    top_k=5,
)

logger.info("Loading model: %s", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device)
model.eval()
logger.info("Model loaded.")


def _generate(messages: list[dict], max_tokens: int = 512) -> str:
    """Run one chat-completion turn and return only the newly generated text.

    Args:
        messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The decoded continuation, stripped of special tokens and whitespace.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # Cap the prompt at 2048 tokens so long contexts cannot overflow the model.
    encoded = tokenizer(rendered, return_tensors="pt", truncation=True, max_length=2048)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Slice off the prompt so only the model's continuation is decoded.
    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()


def _grade_relevance(question: str, sources: list[dict]) -> bool:
    if not sources:
        return False
    top_score = sources[0].get("score", 0)
    q_tokens = set(re.findall(r"\w+", question.lower()))
    doc_tokens = set(re.findall(r"\w+", sources[0].get("text", "").lower()))
    stopwords = {
        "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of",
        "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where",
        "who", "which", "this", "that", "be", "are", "was", "have", "has",
    }
    q_content = q_tokens - stopwords
    overlap = len(q_content & doc_tokens) / max(len(q_content), 1)
    return top_score >= 0.02 or overlap >= 0.35


def crag_answer(message: str, history: list[dict]) -> str:
    """Answer a UTN question using a Corrective-RAG loop.

    Pipeline: retrieve → grade relevance → (if the hit looks irrelevant,
    rewrite the query with the LLM and retrieve again) → generate a
    context-grounded answer.

    Args:
        message: Raw user input from the chat box.
        history: Prior chat turns — required by Gradio's ChatInterface
            signature but unused here.

    Returns:
        The generated answer, or a short prompt to ask a question when the
        input is blank.
    """
    question = message.strip()
    if not question:
        return "Please ask a question about UTN."

    sources = retriever.retrieve(question)

    if not _grade_relevance(question, sources):
        # Corrective step: ask the model for a better-phrased search query.
        rewrite_request = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": REWRITE_PROMPT.format(question=question)},
        ]
        rewritten = _generate(rewrite_request, max_tokens=100).split("\n")[0].strip()
        # Only re-retrieve when the rewrite is non-empty and actually differs.
        if rewritten and rewritten != question:
            sources = retriever.retrieve(rewritten)

    context = retriever.format_context(sources)
    return _generate(build_chat_messages(question, context))


# Chat UI: each submitted message is routed through the CRAG pipeline above.
demo = gr.ChatInterface(
    fn=crag_answer,
    type="messages",  # OpenAI-style message dicts rather than tuple pairs
    title="UTN Student Chatbot",
    description="Ask questions about the University of Technology Nuremberg (UTN) — admissions, programs, courses, deadlines, and more. Powered by a finetuned Qwen3-0.6B with Corrective RAG.",
    examples=[
        "What are the admission requirements for AI & Robotics?",
        "Are there tuition fees?",
        "What courses are in the first semester?",
        "Is there a Welcome Week?",
        "What TOEFL score do I need?",
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()