saeedbenadeeb committed on
Commit
0cf8ad2
·
verified ·
1 Parent(s): d267d41

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """UTN Student Chatbot — Gradio app with CRAG pipeline."""
2
+
3
+ import logging
4
+ import re
5
+
6
+ import gradio as gr
7
+ import spaces
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from prompt import REWRITE_PROMPT, build_chat_messages
12
+ from retriever import Retriever
13
+
14
# Logging is configured at import time so the startup messages below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Identifier handed to both from_pretrained() calls below.
MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged"

# Retrieval side: built once at import and shared by every request.
logger.info("Initializing retriever...")
retriever = Retriever(
    faiss_index_path="faiss.index",
    chunks_meta_path="chunks_meta.jsonl",
    embedding_model="BAAI/bge-small-en-v1.5",
    top_k=5,
)

# Generation side: tokenizer first, then the model weights.
logger.info("Loading model: %s", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # generate() is later called with pad_token_id, so guarantee one exists.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()  # inference only
logger.info("Model loaded.")
40
+
41
+
42
def _generate(messages: list[dict], max_tokens: int = 512) -> str:
    """Generate a single completion for *messages* with the module-level model.

    Args:
        messages: Chat turns (dicts with ``"role"`` and ``"content"`` keys)
            rendered through the tokenizer's chat template.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded continuation only: prompt tokens are sliced off, special
        tokens are skipped, and surrounding whitespace is stripped.
    """
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    # The rendered template already contains the special tokens it needs, so
    # the tokenizer must not add a second set when re-encoding the text.
    # NOTE(review): prompts longer than 2048 tokens are truncated; which end is
    # dropped depends on tokenizer.truncation_side - confirm the generation
    # prompt survives for long retrieved contexts.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
        add_special_tokens=False,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
58
+
59
+
60
+ def _grade_relevance(question: str, sources: list[dict]) -> bool:
61
+ if not sources:
62
+ return False
63
+ top_score = sources[0].get("score", 0)
64
+ q_tokens = set(re.findall(r"\w+", question.lower()))
65
+ doc_tokens = set(re.findall(r"\w+", sources[0].get("text", "").lower()))
66
+ stopwords = {
67
+ "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of",
68
+ "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where",
69
+ "who", "which", "this", "that", "be", "are", "was", "have", "has",
70
+ }
71
+ q_content = q_tokens - stopwords
72
+ overlap = len(q_content & doc_tokens) / max(len(q_content), 1)
73
+ return top_score >= 0.02 or overlap >= 0.35
74
+
75
+
76
@spaces.GPU
def crag_answer(message: str, history: list[dict]) -> str:
    """Answer one chat message: retrieve, optionally rewrite the query, then generate.

    Args:
        message: Raw user input from the chat UI.
        history: Previous chat turns; accepted for the chat callback
            signature but not used here.

    Returns:
        The generated answer followed by the retriever's markdown rendering
        of the sources, or a fixed prompt when *message* is blank.
    """
    query = message.strip()
    if not query:
        return "Please ask a question about UTN."

    sources = retriever.retrieve(query)

    # Corrective step: when the first retrieval is judged irrelevant, ask the
    # model for a rewritten query and retrieve again with it.
    if not _grade_relevance(query, sources):
        rewrite_request = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": REWRITE_PROMPT.format(question=query)},
        ]
        raw_rewrite = _generate(rewrite_request, max_tokens=100)
        # Only the first line of the model output is used as the new query.
        new_query = raw_rewrite.partition("\n")[0].strip()
        if new_query and new_query != query:
            sources = retriever.retrieve(new_query)

    # The answer is always generated for the original question text.
    chat = build_chat_messages(query, retriever.format_context(sources))
    answer = _generate(chat)
    return answer + retriever.format_sources_markdown(sources)
101
+
102
+
103
# Chat UI wired to the CRAG callback; "messages" selects dict-style chat turns.
demo = gr.ChatInterface(
    fn=crag_answer,
    type="messages",
    title="UTN Student Chatbot",
    description="Ask questions about the University of Technology Nuremberg (UTN) — admissions, programs, courses, deadlines, and more. Powered by a finetuned Qwen3-0.6B with Corrective RAG.",
    # Starter questions offered in the interface.
    examples=[
        "What are the admission requirements for AI & Robotics?",
        "Are there tuition fees?",
        "What courses are in the first semester?",
        "Is there a Welcome Week?",
        "What TOEFL score do I need?",
    ],
    theme=gr.themes.Soft(),
)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()