import os
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# =========================================================
# Configuration
# =========================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DOC_FILE = "general.md"        # document answered over; must sit next to app.py
MAX_NEW_TOKENS = 200           # generation budget per answer
TOP_K = 3                      # number of chunks retrieved per question

# =========================================================
# Resolve path
# =========================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)

if not os.path.exists(DOC_PATH):
    raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")

# =========================================================
# Load Qwen Model
# =========================================================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # fp16 only when a GPU is present; fp32 keeps CPU inference numerically safe
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
model.eval()

# =========================================================
# Embedding Model (CPU friendly)
# =========================================================
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# =========================================================
# Document Chunking
# =========================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping word-window chunks.

    Args:
        text: Raw document text.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings (empty list for empty input).
    """
    words = text.split()
    # BUG FIX: the original advanced by `chunk_size - overlap`, which is
    # zero or negative whenever overlap >= chunk_size and looped forever.
    # Clamp the stride to at least one word.
    step = max(1, chunk_size - overlap)
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), step)
    ]


with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()

DOC_CHUNKS = chunk_text(DOC_TEXT)

# Normalized embeddings let retrieval use a plain dot product as cosine
# similarity (see retrieve_context).
DOC_EMBEDS = embedder.encode(
    DOC_CHUNKS,
    normalize_embeddings=True,
    show_progress_bar=True,
)
# =========================================================
# Retrieval
# =========================================================
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*, joined
    by blank lines, ranked by cosine similarity (embeddings are
    L2-normalized, so the dot product IS the cosine score)."""
    q_emb = embedder.encode([question], normalize_embeddings=True)
    scores = np.dot(DOC_EMBEDS, q_emb[0])
    top_ids = scores.argsort()[-k:][::-1]  # indices of the k best, best first
    return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)


# =========================================================
# Clean Answer Extraction (CRITICAL)
# =========================================================
def extract_final_answer(text: str) -> str:
    """Strip prompt echoes from decoded model output and return the answer.

    BUG FIX: the original split a lowercased copy of the text
    (``text.lower().split(m, 1)``), so every answer came back entirely in
    lower case. We now locate each marker case-insensitively but slice the
    ORIGINAL string, preserving the model's capitalization.
    """
    text = text.strip()

    # Drop everything up to and including the first occurrence of each
    # prompt-echo marker (chat-template role tags, "Answer:" prefixes, ...).
    for marker in ("assistant:", "assistant", "answer:", "final answer:"):
        idx = text.lower().find(marker)
        if idx != -1:
            text = text[idx + len(marker):].strip()

    # Fallback: keep only the last non-empty line (matches original behavior).
    lines = [ln.strip() for ln in text.split("\n") if ln.strip()]
    return lines[-1] if lines else text


# =========================================================
# Qwen Inference (ONLY ANSWER)
# =========================================================
def answer_question(question):
    """Generate a short, document-grounded answer for *question*.

    Retrieves the top chunks, builds a strict system prompt, samples up to
    MAX_NEW_TOKENS from the model, and returns the cleaned final answer.
    """
    context = retrieve_context(question)

    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat the context or the question.\n"
                "Respond in 1–2 sentences.\n"
                "If the answer is not present, say:\n"
                "'I could not find this information in the document.'"
            ),
        },
        {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion:\n{question}",
        },
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,   # low temperature keeps answers close to the context
            do_sample=True,
        )

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return extract_final_answer(decoded)
# =========================================================
# Gradio Chat (ONLY Q & A)
# =========================================================
@spaces.GPU()
def chat(user_message, history):
    """Handle one chat turn.

    Returns a (textbox_value, history) pair: the textbox is cleared and the
    (user, bot) tuple is appended to the chat history.
    """
    if not user_message.strip():
        return "", history

    try:
        answer = answer_question(user_message)
    except Exception:
        # BUG FIX: the exception was silently discarded (`as e` was never
        # used), making failures impossible to diagnose from the logs even
        # though `traceback` is imported for exactly this purpose.
        traceback.print_exc()
        answer = "⚠️ An error occurred while generating the answer."

    history.append((user_message, answer))
    return "", history


def reset_chat():
    """Clear the chat history."""
    return []


# =========================================================
# UI
# =========================================================
def build_ui():
    """Build the Gradio interface, launch the server, and return the app."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        chatbot = gr.Chatbot(
            height=420,
            type="tuples",  # history entries are (user, bot) tuples
            # NOTE(review): avatar_images normally expects image file paths
            # or URLs — confirm emoji strings render as intended here.
            avatar_images=("👤", "🤖"),
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Ask a question...",
                lines=2,
                scale=8,
            )
            send = gr.Button("🚀 Send", scale=2)

        clear = gr.Button("🧹 Clear")

        send.click(chat, [msg, chatbot], [msg, chatbot])
        msg.submit(chat, [msg, chatbot], [msg, chatbot])
        clear.click(reset_chat, outputs=chatbot)

    # FIX: launch AFTER the Blocks context closes — Gradio's documented
    # pattern; launching while the layout is still being built is fragile.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
    return demo


# =========================================================
# Entrypoint
# =========================================================
if __name__ == "__main__":
    print(f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}")
    print(f"✅ Model: {MODEL_ID}")
    build_ui()