Spaces:
Running
Running
| import os | |
| import traceback | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sentence_transformers import SentenceTransformer | |
# =========================================================
# Configuration
# =========================================================
# Retrieval / generation knobs.
TOP_K = 3  # number of document chunks handed to the LLM as context
MAX_NEW_TOKENS = 200  # upper bound on tokens generated per answer

# Model and knowledge-base identifiers.
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # Hugging Face Hub id of the chat LLM
DOC_FILE = "general.md"  # document answered from; resolved next to this script
# =========================================================
# Resolve path
# =========================================================
# Anchor the document lookup to this script's directory rather than the
# process working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)
if not os.path.exists(DOC_PATH):
    # Stop at import time with a readable message instead of a bare
    # FileNotFoundError from the later open() call.
    raise RuntimeError(f"❌ {DOC_FILE} not found next to app.py")
# =========================================================
# Load Qwen Model + Embedding Model
# =========================================================
# NOTE(review): Qwen2.5 appears to be supported natively by recent
# transformers releases, so trust_remote_code=True may be unnecessary;
# confirm and drop it to avoid executing code from the model repo.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    # fp16 when a CUDA device is visible, fp32 otherwise (CPU).
    torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
)
# Inference only: put modules such as dropout into evaluation mode.
model.eval()

# Small sentence embedder used for retrieval (CPU friendly).
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# =========================================================
# Document Chunking
# =========================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping windows of whitespace-delimited words.

    Consecutive windows share ``overlap`` words. Chunking stops as soon as a
    window reaches the end of the text, so no trailing chunk is emitted that
    is merely a subset of the previous one.

    Args:
        text: Input string; it is split on any whitespace.
        chunk_size: Maximum number of words per chunk (must be >= 1).
        overlap: Words shared between neighbouring chunks
            (must satisfy 0 <= overlap < chunk_size).

    Returns:
        A list of chunk strings (words re-joined by single spaces). The list
        is empty when *text* contains no words.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Without
            this check, a non-positive step would loop forever and a negative
            overlap would silently skip words.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, "
            f"chunk_size={chunk_size}"
        )

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # This window already covers the last word; any later window would
        # be a redundant suffix of it.
        if start + chunk_size >= len(words):
            break
    return chunks
# Read the knowledge-base document once at startup, chunk it, and embed every
# chunk. errors="ignore" silently drops bytes that are not valid UTF-8.
with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()
DOC_CHUNKS = chunk_text(DOC_TEXT)
if not DOC_CHUNKS:
    # An empty index would make every retrieval fail later (np.dot on an
    # empty embedding matrix), so refuse to start instead.
    raise RuntimeError(f"❌ {DOC_FILE} contains no text to index")
# Unit-normalised embeddings, so a dot product with a normalised query
# embedding is cosine similarity.
DOC_EMBEDS = embedder.encode(
    DOC_CHUNKS,
    normalize_embeddings=True,
    show_progress_bar=True
)
# =========================================================
# Retrieval
# =========================================================
def retrieve_context(question, k=TOP_K):
    """Return the *k* document chunks most similar to *question*.

    Similarity is the dot product between the normalised query embedding and
    the pre-computed, normalised chunk embeddings (i.e. cosine similarity).

    Args:
        question: The user's question text.
        k: Number of chunks to return. It is capped at the number of
            available chunks; a value <= 0 yields no context.

    Returns:
        The selected chunks, best match first, joined by blank lines. Returns
        "" when no chunks are requested or available.
    """
    # Clamp k. Without this, k == 0 slices argsort()[-0:], which is the
    # *whole* array and would dump the entire document into the prompt.
    k = min(k, len(DOC_CHUNKS))
    if k <= 0:
        return ""
    q_emb = embedder.encode([question], normalize_embeddings=True)
    scores = np.dot(DOC_EMBEDS, q_emb[0])
    # Indices in descending score order; keep the k best.
    top_ids = scores.argsort()[::-1][:k]
    return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)
# =========================================================
# Clean Answer Extraction (CRITICAL)
# =========================================================
def extract_final_answer(text: str) -> str:
    """Strip any echoed prompt and leading answer label from model output.

    If *text* contains a line that is exactly ``assistant`` (case-insensitive)
    or starts with ``assistant:``, only what follows the LAST such line is
    kept. A bare ``assistant`` line is, for example, what remains of a
    chat-template role header when a full prompt+completion sequence is
    decoded with special tokens skipped. A leading ``final answer:`` or
    ``answer:`` label is then removed.

    Unlike a naive substring split, this function:
      * preserves the answer's original casing;
      * ignores the word "assistant" when it appears inside ordinary prose;
      * keeps multi-line answers intact instead of returning only the last line.

    Args:
        text: Raw decoded model output.

    Returns:
        The cleaned answer, stripped of surrounding whitespace ("" if nothing
        remains).
    """
    lines = text.strip().split("\n")

    # Scan from the end so that a role header inside an echoed prompt never
    # wins over the one that immediately precedes the generated answer.
    answer_lines = lines
    for idx in range(len(lines) - 1, -1, -1):
        stripped = lines[idx].strip()
        lowered = stripped.lower()
        if lowered == "assistant":
            answer_lines = lines[idx + 1:]
            break
        if lowered.startswith("assistant:"):
            answer_lines = [stripped[len("assistant:"):]] + lines[idx + 1:]
            break

    answer = "\n".join(answer_lines).strip()

    # Drop a single leading label while keeping the rest of the answer's case.
    lowered_answer = answer.lower()
    for label in ("final answer:", "answer:"):
        if lowered_answer.startswith(label):
            answer = answer[len(label):].strip()
            break
    return answer
# =========================================================
# Qwen Inference (ONLY ANSWER)
# =========================================================
def answer_question(question):
    """Answer *question* using retrieved document chunks as grounding context.

    Builds a system + user chat prompt containing the top retrieved chunks,
    samples a completion from the model, and returns the cleaned answer.
    Sampling is enabled (temperature 0.3), so repeated calls may produce
    different wording.

    Args:
        question: The user's question text.

    Returns:
        The model's answer as post-processed by ``extract_final_answer``.
    """
    context = retrieve_context(question)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat the context or the question.\n"
                "Respond in 1–2 sentences.\n"
                "If the answer is not present, say:\n"
                "'I could not find this information in the document.'"
            )
        },
        {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion:\n{question}"
        }
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            do_sample=True
        )
    # generate() returns prompt + continuation for this decoder-only model.
    # Decode only the new tokens so the system prompt and context can never
    # leak into the answer or confuse extract_final_answer().
    new_tokens = output[0][prompt_len:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return extract_final_answer(decoded)
# =========================================================
# Gradio Chat (ONLY Q & A)
# =========================================================
def chat(user_message, history):
    """Gradio handler: answer one user message and append it to the history.

    Args:
        user_message: Text from the input box (may be None or blank).
        history: List of (user, bot) message tuples. None is treated as an
            empty history. The list is appended to in place.

    Returns:
        A tuple of an empty string (clears the input box) and the history.
        Blank messages leave the history unchanged.
    """
    if history is None:
        history = []
    if not (user_message or "").strip():
        return "", history
    try:
        answer = answer_question(user_message)
    except Exception:
        # UI boundary: keep the chat usable, but do not lose the failure.
        # Print the full traceback to the server log.
        traceback.print_exc()
        answer = "⚠️ An error occurred while generating the answer."
    history.append((user_message, answer))
    return "", history
def reset_chat():
    """Return a fresh, empty chat history (wired to the Clear button)."""
    empty_history = list()
    return empty_history
# =========================================================
# UI
# =========================================================
def build_ui():
    """Build the chat page, launch it on 0.0.0.0:7860, and return the Blocks app.

    NOTE(review): when the app runs as a plain script, Gradio's ``launch()``
    normally blocks until the server stops, so the ``return`` is only reached
    afterwards. Confirm before relying on the returned object.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # NOTE(review): Gradio documents avatar_images as image paths/URLs;
        # the emoji strings here may not render as avatars. Verify in the app.
        chatbot = gr.Chatbot(
            avatar_images=("👤", "🤖"),
            type="tuples",
            height=420,
        )
        with gr.Row():
            msg = gr.Textbox(scale=8, lines=2, placeholder="Ask a question...")
            send = gr.Button("🚀 Send", scale=2)
        clear = gr.Button("🧹 Clear")

        # Clicking Send and submitting the textbox both route through chat().
        for register in (send.click, msg.submit):
            register(fn=chat, inputs=[msg, chatbot], outputs=[msg, chatbot])
        clear.click(fn=reset_chat, outputs=chatbot)

    demo.launch(share=False, server_port=7860, server_name="0.0.0.0")
    return demo
# =========================================================
# Entrypoint
# =========================================================
if __name__ == "__main__":
    # Startup banner (one line per item), then start the UI.
    print(
        f"✅ Loaded {len(DOC_CHUNKS)} chunks from {DOC_FILE}",
        f"✅ Model: {MODEL_ID}",
        sep="\n",
    )
    build_ui()