rahul7star commited on
Commit
2737e4c
·
verified ·
1 Parent(s): 0bc8034

Update app_qwen.py

Browse files
Files changed (1) hide show
  1. app_qwen.py +148 -181
app_qwen.py CHANGED
@@ -1,27 +1,34 @@
1
- import spaces
2
  import os
3
- import textwrap
4
  import traceback
5
  import gradio as gr
6
  import torch
 
 
7
 
8
- from transformers import (
9
- pipeline,
10
- AutoTokenizer,
11
- AutoModelForCausalLM,
12
- )
13
 
14
- # ---------------------------
15
  # Configuration
16
- # ---------------------------
17
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
 
 
 
 
18
 
19
- ROOT_DIR = "."
20
- ALLOWED_EXT = (".txt", ".md")
 
 
 
21
 
22
- # ---------------------------
23
- # Load lightweight model
24
- # ---------------------------
 
 
 
25
  tokenizer = AutoTokenizer.from_pretrained(
26
  MODEL_ID,
27
  trust_remote_code=True
@@ -34,203 +41,163 @@ model = AutoModelForCausalLM.from_pretrained(
34
  trust_remote_code=True
35
  )
36
 
37
- pipe = pipeline(
38
- "text-generation",
39
- model=model,
40
- tokenizer=tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
 
43
- # ---------------------------
44
- # Research loader (project root)
45
- # ---------------------------
46
def load_research_from_root(max_total_chars: int = 12000):
    """Gather the project's .txt/.md research files into one context string.

    Scans ROOT_DIR (non-recursively), skipping requirements.txt and this
    script itself.  Each file is capped at 8000 characters, and the combined
    result is capped at *max_total_chars*.

    Returns the concatenated context, or a placeholder message when no
    eligible files exist.
    """
    candidates = [
        name
        for name in sorted(os.listdir(ROOT_DIR))
        if name.lower().endswith(ALLOWED_EXT)
        and name != "requirements.txt"
        and name != os.path.basename(__file__)
    ]

    if not candidates:
        return "No research files (.txt/.md) found in project root."

    sections = []
    running_total = 0

    for fname in candidates:
        path = os.path.join(ROOT_DIR, fname)
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                body = fh.read()
        except Exception as exc:
            # Keep going on unreadable files; surface the error inline.
            body = f"[Error reading {fname}: {exc}]"

        # Per-file cap so one huge file cannot crowd out the others.
        if len(body) > 8000:
            body = body[:8000] + "\n\n[TRUNCATED]\n"

        section = f"--- {fname} ---\n{body.strip()}\n"
        sections.append(section)
        running_total += len(section)

        if running_total >= max_total_chars:
            break

    # Final hard cap on the joined text (separators may push it over).
    return "\n\n".join(sections)[:max_total_chars]
78
-
79
- # ---------------------------
80
- # System prompts
81
- # ---------------------------
82
- research_context = load_research_from_root()
83
-
84
def get_system_prompt(mode="chat"):
    """Build the system prompt for the given mode.

    ``mode == "chat"`` selects the conversational persona; any other value
    selects the research/analytical persona.  The loaded research context is
    embedded verbatim so the model can ground its answers in it.
    """
    if mode == "chat":
        header = textwrap.dedent("""
            You are OhamLab AI.

            Mode: Conversational Q&A.

            Rules:
            - Answer clearly in 3–6 sentences.
            - Prefer accuracy over creativity.
            - Use the research context to answer questions.
            - Treat markdown headings as semantic sections.
            - If the answer is not in the research context, say so.
        """)
    else:
        header = textwrap.dedent("""
            You are OhamLab AI.

            Mode: Research / Analytical.

            Rules:
            - Use structured reasoning and sections.
            - Reference the research context when relevant.
            - Be precise and analytical.
            - Treat markdown headings as semantic structure.
        """)

    prompt = (
        f"{header}\n"
        "--- BEGIN RESEARCH CONTEXT ---\n"
        f"{research_context}\n"
        "--- END RESEARCH CONTEXT ---"
    )
    return prompt.strip()
118
-
119
- # ---------------------------
120
- # State
121
- # ---------------------------
122
- conversation_mode = "chat"
123
- history_messages = [{"role": "system", "content": get_system_prompt("chat")}]
124
- chat_history_for_ui = []
125
-
126
- # ---------------------------
127
- # Model call helper
128
- # ---------------------------
129
def call_model_get_response(messages, max_tokens=600):
    """Render *messages* as a flat role-tagged transcript and generate a reply.

    Parameters
    ----------
    messages : list[dict]
        Chat history as ``{"role": ..., "content": ...}`` dicts.
    max_tokens : int
        Upper bound on newly generated tokens.

    Returns the generated assistant text, or a formatted error string if the
    pipeline call fails (the UI displays it instead of crashing).
    """
    conversation_text = "".join(
        f"[{m['role'].upper()}]: {m['content']}\n" for m in messages
    )
    conversation_text += "[ASSISTANT]:"

    try:
        output = pipe(
            conversation_text,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            repetition_penalty=1.1,
            return_full_text=False,
        )
        return output[0]["generated_text"].strip()

    except Exception as e:
        # BUG FIX: the original interpolated tb.splitlines()[-6:] directly,
        # which rendered a raw Python list repr (['...', '...']) in the
        # user-facing message.  Join the last lines into readable text.
        tail = "\n".join(traceback.format_exc().splitlines()[-6:])
        return f"⚠️ Error: {e}\n\n{tail}"
153
 
154
- # ---------------------------
155
- # Chat logic
156
- # ---------------------------
157
@spaces.GPU()
def chat_with_model(user_message, chat_history):
    """Handle one chat turn for the Gradio UI.

    Recognizes the in-band commands "switch to research mode" and
    "switch to chat mode", which reset the conversation with a new system
    prompt.  Otherwise forwards the full history to the model and appends
    the exchange to the UI transcript.

    Returns ("", updated_history) so the textbox is cleared.
    """
    global history_messages, chat_history_for_ui, conversation_mode

    # Ignore empty/whitespace-only submissions.
    if not user_message.strip():
        return "", chat_history

    lowered = user_message.lower()

    # NOTE(review): mode switches append the notice to the *local*
    # chat_history only, not chat_history_for_ui — the notice disappears on
    # the next turn.  Preserved as-is; confirm whether that is intended.
    if "switch to research mode" in lowered:
        conversation_mode = "research"
        history_messages = [
            {"role": "system", "content": get_system_prompt("research")}
        ]
        return "", chat_history + [("🟢 Mode", "🔬 Research mode activated.")]

    if "switch to chat mode" in lowered:
        conversation_mode = "chat"
        history_messages = [
            {"role": "system", "content": get_system_prompt("chat")}
        ]
        return "", chat_history + [("🟢 Mode", "💬 Chat mode activated.")]

    history_messages.append({"role": "user", "content": user_message})
    reply = call_model_get_response(history_messages)
    history_messages.append({"role": "assistant", "content": reply})
    chat_history_for_ui.append((user_message, reply))

    return "", chat_history_for_ui
 
184
 
185
def reset_chat():
    """Reset the conversation: fresh system prompt for the current mode,
    empty UI transcript.  Returns [] to clear the Chatbot component."""
    global history_messages, chat_history_for_ui
    chat_history_for_ui = []
    history_messages = [
        {"role": "system", "content": get_system_prompt(conversation_mode)}
    ]
    return []
190
 
191
- # ---------------------------
192
- # Gradio UI
193
- # ---------------------------
194
def build_ui():
    """Assemble the Gradio Blocks app, launch it, and return the demo."""
    custom_css = """
    #chatbot {
        background-color: #f9f9fb;
        border-radius: 12px;
        padding: 10px;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        with gr.Row():
            clear_btn = gr.Button("🧹 Clear", size="sm")

        chatbot = gr.Chatbot(
            height=400,
            type="tuples",
            avatar_images=("👤", "🤖"),
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Ask a question about the markdown files...",
                lines=2,
                scale=8,
            )
            send = gr.Button("🚀 Send", variant="primary", scale=2)

        # Button click and textbox Enter share the same handler.
        send.click(chat_with_model, [msg, chatbot], [msg, chatbot])
        msg.submit(chat_with_model, [msg, chatbot], [msg, chatbot])
        clear_btn.click(reset_chat, outputs=chatbot)

    # launch() blocks until the server stops; demo is returned afterwards.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

    return demo
230
 
231
- # ---------------------------
232
  # Entrypoint
233
- # ---------------------------
234
  if __name__ == "__main__":
235
- print(f"✅ Starting app with model: {MODEL_ID}")
 
236
  build_ui()
 
 
1
  import os
 
2
  import traceback
3
  import gradio as gr
4
  import torch
5
+ import spaces
6
+ import numpy as np
7
 
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from sentence_transformers import SentenceTransformer
 
 
 
10
 
11
+ # =========================================================
12
  # Configuration
13
+ # =========================================================
14
  MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
15
+ GENERAL_MD = "general.md"
16
+
17
+ MAX_NEW_TOKENS = 300
18
+ TOP_K = 3
19
 
20
+ # =========================================================
21
+ # Resolve path (CRITICAL)
22
+ # =========================================================
23
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ GENERAL_MD_PATH = os.path.join(BASE_DIR, GENERAL_MD)
25
 
26
+ if not os.path.exists(GENERAL_MD_PATH):
27
+ raise RuntimeError(f"❌ {GENERAL_MD} not found next to app.py")
28
+
29
+ # =========================================================
30
+ # Load Model
31
+ # =========================================================
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  MODEL_ID,
34
  trust_remote_code=True
 
41
  trust_remote_code=True
42
  )
43
 
44
+ model.eval()
45
+
46
+ # =========================================================
47
+ # Embedding Model (CPU-friendly)
48
+ # =========================================================
49
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
50
+
51
+ # =========================================================
52
+ # Load & Chunk general.md
53
+ # =========================================================
54
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping word-based chunks.

    Parameters
    ----------
    text : str
        Source text; split on whitespace.
    chunk_size : int
        Maximum words per chunk (must be positive).
    overlap : int
        Words shared between consecutive chunks (0 <= overlap < chunk_size).

    Returns a list of chunk strings ([] for empty/whitespace-only text).

    Raises ValueError on invalid sizes.  BUG FIX: the original advanced the
    cursor by ``chunk_size - overlap`` with no validation, so any call with
    ``overlap >= chunk_size`` looped forever.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    step = chunk_size - overlap
    return [
        " ".join(words[i:i + chunk_size])
        for i in range(0, len(words), step)
    ]
63
+
64
+ with open(GENERAL_MD_PATH, "r", encoding="utf-8", errors="ignore") as f:
65
+ md_text = f.read()
66
+
67
+ DOC_CHUNKS = chunk_text(md_text)
68
+ DOC_SOURCES = [GENERAL_MD] * len(DOC_CHUNKS)
69
+
70
+ if not DOC_CHUNKS:
71
+ raise RuntimeError("❌ general.md is empty or unreadable")
72
+
73
+ # =========================================================
74
+ # Embed once
75
+ # =========================================================
76
+ DOC_EMBEDS = embedder.encode(
77
+ DOC_CHUNKS,
78
+ normalize_embeddings=True,
79
+ show_progress_bar=True
80
  )
81
 
82
+ # =========================================================
83
+ # Retrieval
84
+ # =========================================================
85
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*.

    Embeddings are L2-normalized, so the dot product is cosine similarity.
    Each chunk is prefixed with its source filename; chunks are joined with
    blank lines, most similar first.
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarities = np.dot(DOC_EMBEDS, query_vec)

    # Top-k indices, best match first (same tie ordering as the original).
    best = similarities.argsort()[-k:][::-1]

    labeled = [
        f"[Source: {DOC_SOURCES[idx]}]\n{DOC_CHUNKS[idx]}"
        for idx in best
    ]
    return "\n\n".join(labeled)
95
+
96
+ # =========================================================
97
+ # Qwen ChatML Inference
98
+ # =========================================================
99
def answer_question(question):
    """Answer *question* with retrieval-augmented generation over general.md.

    Retrieves the top-k relevant chunks, builds a ChatML prompt via the
    tokenizer's chat template, and generates a grounded answer with Qwen.

    Returns the decoded answer text (newly generated tokens only).
    """
    context = retrieve_context(question)

    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY using the provided context.\n"
                "If the answer is not present, say:\n"
                "'I could not find this information in the document.'"
            ),
        },
        {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion:\n{question}\n",
        },
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True,
        )

    # BUG FIX: decode only the newly generated tokens.  The original decoded
    # output[0] in full, which echoed the entire ChatML prompt (system rules,
    # retrieved context, and the question) back to the user before the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
 
 
141
 
142
+ # =========================================================
143
+ # Gradio Chat
144
+ # =========================================================
145
@spaces.GPU()
def chat(user_message, history):
    """Gradio callback: answer one user message via the RAG pipeline.

    Blank submissions are ignored.  Any exception from generation is caught
    and rendered into the reply so the UI never crashes.

    Returns ("", history) so the textbox is cleared after sending.
    """
    if not user_message.strip():
        return "", history

    try:
        reply = answer_question(user_message)
    except Exception as exc:
        # Show the failure inline instead of breaking the chat session.
        reply = f"⚠️ Error:\n{exc}\n\n{traceback.format_exc()}"

    history.append((user_message, reply))
    return "", history
158
 
159
def reset_chat():
    """Clear the Chatbot display by returning an empty transcript."""
    empty_transcript = []
    return empty_transcript
161
 
162
+ # =========================================================
163
+ # UI
164
+ # =========================================================
165
def build_ui():
    """Assemble the Gradio chat UI, launch the server, and return the demo."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 📄 Q&A from general.md (Qwen2.5-0.5B + RAG)")

        chatbot = gr.Chatbot(
            height=420,
            avatar_images=("👤", "🤖"),
            type="tuples",
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Ask a question from general.md...",
                lines=2,
                scale=8,
            )
            send = gr.Button("🚀 Send", scale=2)

        clear = gr.Button("🧹 Clear")

        # Button click and textbox Enter share the same handler.
        send.click(chat, [msg, chatbot], [msg, chatbot])
        msg.submit(chat, [msg, chatbot], [msg, chatbot])
        clear.click(reset_chat, outputs=chatbot)

    # launch() blocks until the server stops; demo is returned afterwards.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )

    return demo
196
 
197
+ # =========================================================
198
  # Entrypoint
199
+ # =========================================================
200
  if __name__ == "__main__":
201
+ print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from general.md")
202
+ print(f"✅ Model: {MODEL_ID}")
203
  build_ui()