saeedbenadeeb's picture
Upload app.py with huggingface_hub
1ec87e2 verified
"""UTN Student Chatbot — Gradio app with CRAG pipeline."""
import logging
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from prompt import REWRITE_PROMPT, build_chat_messages
from retriever import Retriever
# --- Logging ---------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub repo id of the generator model loaded below.
MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged"

# --- Retrieval -------------------------------------------------------------
# Dense retriever over a prebuilt FAISS index plus per-chunk metadata; returns
# the top 5 chunks per query.
logger.info("Initializing retriever...")
retriever = Retriever(
    faiss_index_path="faiss.index",
    chunks_meta_path="chunks_meta.jsonl",
    embedding_model="BAAI/bge-small-en-v1.5",
    top_k=5,
)

# --- Generator -------------------------------------------------------------
logger.info("Loading model: %s", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # generate() is given pad_token_id explicitly; fall back to EOS when the
    # tokenizer ships without a dedicated pad token.
    tokenizer.pad_token = tokenizer.eos_token

# Use bfloat16 on GPU; stay in full precision on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16
else:
    device = "cpu"
    dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.to(device)
model.eval()  # inference only: disables dropout
logger.info("Model loaded.")
def _generate(
    messages: list[dict],
    max_tokens: int = 512,
    *,
    max_input_tokens: int = 2048,
    temperature: float = 0.3,
    top_p: float = 0.9,
) -> str:
    """Run one sampled chat completion with the loaded model and return the reply.

    Args:
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts),
            rendered with the tokenizer's chat template with thinking disabled.
        max_tokens: Maximum number of new tokens to generate.
        max_input_tokens: Prompt length budget in tokens. Longer prompts are
            truncated from the *left*, so the trailing user turn and the
            assistant generation header are always kept.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling mass passed to ``model.generate``.

    Returns:
        The decoded continuation only (prompt tokens excluded), with special
        tokens removed and surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    # The rendered template already contains its special tokens; don't let the
    # tokenizer add another set on top of them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    prompt_len = inputs["input_ids"].shape[1]
    if prompt_len > max_input_tokens:
        # The tokenizer's own truncation drops tokens from the right by default,
        # which would remove the question and the assistant header and leave
        # the model continuing the retrieved context. Keep the tail instead.
        logger.warning(
            "Prompt has %d tokens; keeping the last %d.", prompt_len, max_input_tokens,
        )
        inputs = {k: v[:, -max_input_tokens:] for k, v in inputs.items()}
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Slice off the prompt so only newly generated tokens are decoded.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Function words ignored when measuring question/document term overlap.
# Built once at import time instead of on every call.
_STOPWORDS = frozenset({
    "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of",
    "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where",
    "who", "which", "this", "that", "be", "are", "was", "have", "has",
})
_WORD_RE = re.compile(r"\w+")


def _grade_relevance(
    question: str,
    sources: list[dict],
    *,
    score_threshold: float = 0.02,
    overlap_threshold: float = 0.35,
) -> bool:
    """Decide whether the top retrieved chunk looks relevant to *question*.

    Only ``sources[0]`` is inspected. The retrieval is graded relevant when
    either its ``"score"`` is at least *score_threshold*, or the fraction of
    the question's non-stopword tokens that also appear in its ``"text"`` is
    at least *overlap_threshold*.

    Args:
        question: The user's question.
        sources: Retrieved chunks as dicts with optional ``"score"`` and
            ``"text"`` keys; missing or ``None`` values count as 0 / "".
        score_threshold: Minimum retriever score that alone marks relevance.
        overlap_threshold: Minimum lexical-overlap ratio that alone marks
            relevance.

    Returns:
        ``False`` for an empty *sources* list, otherwise the grading result.
    """
    if not sources:
        return False
    top = sources[0]
    if (top.get("score") or 0) >= score_threshold:
        return True
    q_content = set(_WORD_RE.findall(question.lower())) - _STOPWORDS
    doc_tokens = set(_WORD_RE.findall((top.get("text") or "").lower()))
    # max(..., 1) guards against an all-stopword question (overlap becomes 0).
    overlap = len(q_content & doc_tokens) / max(len(q_content), 1)
    return overlap >= overlap_threshold
def _rewrite_query(question: str) -> str:
    """Ask the model to rephrase *question* for retrieval; return "" if not useful."""
    rewrite_msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": REWRITE_PROMPT.format(question=question)},
    ]
    rewritten = _generate(rewrite_msgs, max_tokens=100)
    # Keep only the first line so any trailing commentary is not used as the query.
    rewritten = rewritten.split("\n")[0].strip()
    if rewritten == question:
        return ""
    return rewritten


def crag_answer(message: str, history: list[dict]) -> str:
    """Answer *message* with a Corrective-RAG pipeline (Gradio chat handler).

    Steps: retrieve chunks for the question. Grade the top hit with
    ``_grade_relevance``. If it looks irrelevant, ask the model to rewrite the
    query and retrieve again. Then generate an answer grounded in the
    formatted context.

    Args:
        message: The user's latest chat message.
        history: Prior chat turns supplied by ``gr.ChatInterface``.
            NOTE(review): currently unused, so every turn is answered
            independently and follow-up questions lose earlier context.

    Returns:
        The model's answer, or a prompt to ask something if *message* is blank.
    """
    question = message.strip()
    if not question:
        return "Please ask a question about UTN."
    sources = retriever.retrieve(question)
    if not _grade_relevance(question, sources):
        rewritten = _rewrite_query(question)
        if rewritten:
            retried = retriever.retrieve(rewritten)
            # Only replace the original hits when the rewrite found something;
            # otherwise the answer would be generated with no context at all.
            if retried:
                sources = retried
    context = retriever.format_context(sources)
    messages = build_chat_messages(question, context)
    return _generate(messages)
# Header text shown above the chat window.
_DESCRIPTION = (
    "Ask questions about the University of Technology Nuremberg (UTN) — "
    "admissions, programs, courses, deadlines, and more. "
    "Powered by a finetuned Qwen3-0.6B with Corrective RAG."
)

# Starter questions offered as clickable examples in the chat UI.
_EXAMPLE_QUESTIONS = [
    "What are the admission requirements for AI & Robotics?",
    "Are there tuition fees?",
    "What courses are in the first semester?",
    "Is there a Welcome Week?",
    "What TOEFL score do I need?",
]

# Chat UI: each user message is answered by crag_answer; history is passed
# in the "messages" (role/content dict) format.
demo = gr.ChatInterface(
    fn=crag_answer,
    type="messages",
    title="UTN Student Chatbot",
    description=_DESCRIPTION,
    examples=_EXAMPLE_QUESTIONS,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()