# app_qwen.py — OhamLab-AI Hugging Face Space (author: rahul7star, commit 3912f7f)
# Minimal RAG Q&A app: retrieves chunks of general.md and answers with Qwen2.5-0.5B-Instruct.
import os
import traceback
import gradio as gr
import torch
import spaces
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
# =========================================================
# Configuration
# =========================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # chat model used to generate answers
DOC_FILE = "general.md"                  # knowledge base; must sit next to this script
MAX_NEW_TOKENS = 200                     # generation budget per answer
TOP_K = 3                                # number of chunks retrieved per question

# =========================================================
# Resolve path
# =========================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
if not os.path.exists(DOC_PATH):
    # Fail fast with an accurate message: the document must live beside THIS
    # file (app_qwen.py) — the old message incorrectly said "app.py".
    raise RuntimeError(
        f"❌ {DOC_FILE} not found next to {os.path.basename(__file__)}"
    )
# =========================================================
# Load Qwen Model
# =========================================================
# Tokenizer and causal-LM weights are fetched/loaded once at import time.
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
trust_remote_code=True
)
# fp16 on GPU halves memory; fall back to fp32 on CPU, where fp16 kernels
# are slow or unsupported. device_map="auto" places layers automatically.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True
)
# Inference-only: switch off dropout / training-mode layers.
model.eval()
# =========================================================
# Embedding Model (CPU friendly)
# =========================================================
# Small sentence-embedding model used only for chunk retrieval; fast on CPU.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# =========================================================
# Document Chunking
# =========================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping word-based chunks.

    Args:
        text: Source text to split (split on whitespace).
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        list[str]: Space-joined chunks covering the whole text; the final
        chunk may be shorter than chunk_size. Empty text yields [].

    Raises:
        ValueError: If overlap >= chunk_size — the original while-loop would
        never advance (step <= 0) and spin forever.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    # Same start indices as the original while-loop: 0, step, 2*step, ...
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
# Read the source document once at import time, chunk it, and precompute a
# normalized embedding for every chunk (consumed by retrieve_context below).
with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as doc_file:
    DOC_TEXT = doc_file.read()

DOC_CHUNKS = chunk_text(DOC_TEXT)
DOC_EMBEDS = embedder.encode(
    DOC_CHUNKS,
    show_progress_bar=True,
    normalize_embeddings=True,
)
# =========================================================
# Retrieval
# =========================================================
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*.

    Embeddings are L2-normalized, so the dot product is cosine similarity.
    Chunks are joined with blank lines, most similar first.
    """
    query_vec = embedder.encode([question], normalize_embeddings=True)[0]
    similarities = DOC_EMBEDS @ query_vec
    best_ids = similarities.argsort()[::-1][:k]
    return "\n\n".join(DOC_CHUNKS[idx] for idx in best_ids)
# =========================================================
# Clean Answer Extraction (CRITICAL)
# =========================================================
def extract_final_answer(text: str) -> str:
    """Strip prompt/marker echoes from generated text and return the answer.

    Markers such as "Assistant:" / "Answer:" are matched case-insensitively,
    but the ORIGINAL casing of the answer is preserved. (The previous
    implementation split on text.lower(), which lowercased the entire
    returned answer as a side effect.)

    Falls back to the last non-empty line so a trailing echo of the
    question/context does not leak through.
    """
    text = text.strip()
    # Remove prompt echoes — everything up to and including the first
    # occurrence of each marker, checked in order.
    markers = ["assistant:", "assistant", "answer:", "final answer:"]
    for m in markers:
        idx = text.lower().find(m)
        if idx != -1:
            # Slice the ORIGINAL text so its case is untouched.
            text = text[idx + len(m):].strip()
    # Last line fallback
    lines = [l.strip() for l in text.split("\n") if l.strip()]
    return lines[-1] if lines else text
# =========================================================
# Qwen Inference (ONLY ANSWER)
# =========================================================
def answer_question(question):
    """Generate a document-grounded answer for *question*.

    Retrieves the TOP_K most similar chunks, builds a chat prompt via the
    tokenizer's chat template, samples up to MAX_NEW_TOKENS tokens, and
    returns the cleaned final answer string.
    """
    context = retrieve_context(question)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat the context or the question.\n"
                "Respond in 1–2 sentences.\n"
                "If the answer is not present, say:\n"
                "'I could not find this information in the document.'"
            ),
        },
        {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion:\n{question}",
        },
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True,
        )
    # Decode ONLY the newly generated tokens. The old code decoded the whole
    # sequence (prompt included) and relied entirely on marker stripping to
    # remove the prompt echo, which leaked context into answers.
    generated = output[0][inputs["input_ids"].shape[1]:]
    decoded = tokenizer.decode(generated, skip_special_tokens=True)
    return extract_final_answer(decoded)
# =========================================================
# Gradio Chat (ONLY Q & A)
# =========================================================
@spaces.GPU()
def chat(user_message, history):
    """Gradio callback: answer *user_message* and append the pair to *history*.

    Returns ("", history) so the textbox clears after each send.
    """
    if not user_message.strip():
        # Ignore empty / whitespace-only submissions.
        return "", history
    try:
        answer = answer_question(user_message)
    except Exception:
        # Log the full stack trace for server-side debugging (the old code
        # bound the exception to an unused name and discarded it silently);
        # the user still sees a friendly message.
        traceback.print_exc()
        answer = "⚠️ An error occurred while generating the answer."
    history.append((user_message, answer))
    return "", history
def reset_chat():
    """Clear the chat history (fresh empty list for the Chatbot component)."""
    return list()
# =========================================================
# UI
# =========================================================
def build_ui():
    """Assemble the Gradio Blocks interface and launch the server.

    Note: demo.launch() blocks until the server stops, so the trailing
    `return demo` only runs after shutdown (kept for API compatibility).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        chatbot = gr.Chatbot(
            height=420,
            type="tuples",
            avatar_images=("👤", "🤖"),
        )
        with gr.Row():
            msg = gr.Textbox(
                placeholder="Ask a question...",
                lines=2,
                scale=8,
            )
            send = gr.Button("🚀 Send", scale=2)
            clear = gr.Button("🧹 Clear")

        # Both the Send button and pressing Enter submit the question.
        for trigger in (send.click, msg.submit):
            trigger(chat, [msg, chatbot], [msg, chatbot])
        clear.click(reset_chat, outputs=chatbot)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
    return demo
# =========================================================
# Entrypoint
# =========================================================
if __name__ == "__main__":
    # Startup summary, then the (blocking) UI launch.
    print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")
    print(f"✅ Model: {MODEL_ID}")
    build_ui()