import os, re, uuid from pathlib import Path _CACHE_DIR = "/data/hf_cache" os.makedirs(_CACHE_DIR, exist_ok=True) os.environ["HF_HOME"] = _CACHE_DIR os.environ["TRANSFORMERS_CACHE"] = _CACHE_DIR os.environ["HF_DATASETS_CACHE"] = _CACHE_DIR import gradio as gr import chromadb from chromadb.config import Settings import fitz from docx import Document as DocxDocument from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline import torch EMBED_MODEL_ID = "intfloat/multilingual-e5-small" LLM_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" COLLECTION_NAME = "rag_docs" CHROMA_PATH = "/data/chromadb" CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_NEW_TOKENS = 512, 64, 4, 512 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"[INFO] Loading embedding: {EMBED_MODEL_ID}") embed_model = SentenceTransformer(EMBED_MODEL_ID, device=DEVICE) print(f"[INFO] Loading LLM: {LLM_MODEL_ID}") tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID) llm_model = AutoModelForCausalLM.from_pretrained( LLM_MODEL_ID, dtype=torch.float16 if DEVICE == "cuda" else torch.float32, device_map="auto" if DEVICE == "cuda" else None, ) llm_pipeline = pipeline( "text-generation", model=llm_model, tokenizer=tokenizer, device=DEVICE if DEVICE == "cpu" else None, ) # ── ChromaDB ────────────────────────────────────────────────────────────────── def _make_chroma(): if os.path.exists("/data"): os.makedirs(CHROMA_PATH, exist_ok=True) return chromadb.PersistentClient(path=CHROMA_PATH) return chromadb.Client(Settings(anonymized_telemetry=False)) chroma_client = _make_chroma() collection = chroma_client.get_or_create_collection( name=COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) # ── Document loaders ────────────────────────────────────────────────────────── def load_pdf(p): doc = fitz.open(p) t = "\n\n".join(pg.get_text() for pg in doc) doc.close() return t def load_docx(p): return "\n\n".join(para.text for para in DocxDocument(p).paragraphs if para.text.strip()) def load_text(p): return open(p, encoding="utf-8", errors="ignore").read() def load_document(p): ext = Path(p).suffix.lower() if ext == ".pdf": return load_pdf(p) if ext == ".docx": return load_docx(p) if ext in (".txt", ".md"): return load_text(p) raise ValueError(f"Unsupported: {ext}") # ── Chunking ────────────────────────────────────────────────────────────────── def split_text(text): text = re.sub(r"\n{3,}", "\n\n", text).strip() chunks, start = [], 0 while start < len(text): end = start + CHUNK_SIZE if end >= len(text): chunks.append(text[start:]) break sp = end for sep in ["\n\n", "\n", ". "]: idx = text.rfind(sep, start + CHUNK_OVERLAP, end) if idx != -1: sp = idx + len(sep) break chunks.append(text[start:sp]) start = sp - CHUNK_OVERLAP return [c.strip() for c in chunks if len(c.strip()) > 30] # ── Embeddings ──────────────────────────────────────────────────────────────── def embed_passages(texts): return embed_model.encode( [f"passage: {t}" for t in texts], show_progress_bar=False ).tolist() def embed_query(q): return embed_model.encode([f"query: {q}"]).tolist() # ── RAG core ────────────────────────────────────────────────────────────────── def index_document(file_obj): if file_obj is None: return "⚠️ Chưa chọn file.", collection.count() filename = Path(file_obj.name).name try: text = load_document(file_obj.name) except ValueError as e: return f"❌ {e}", collection.count() chunks = split_text(text) if not chunks: return "⚠️ Tài liệu rỗng.", collection.count() embs = embed_passages(chunks) ids = [f"{filename}_{i}_{uuid.uuid4().hex[:6]}" for i in range(len(chunks))] metas = [{"source": filename, "chunk": i} for i in range(len(chunks))] collection.add(documents=chunks, embeddings=embs, ids=ids, metadatas=metas) return f"✅ {filename} — {len(chunks)} chunks đã index.", collection.count() def retrieve(query): if collection.count() == 0: return [] res = collection.query( query_embeddings=embed_query(query), n_results=min(TOP_K, collection.count()), include=["documents", "metadatas", "distances"], ) return [ {"text": d, "source": m["source"]} for d, m, s in zip(res["documents"][0], res["metadatas"][0], res["distances"][0]) ] def answer_question(question, history): if not question.strip(): return history, "" if collection.count() == 0: history.append({"role": "user", "content": question}) history.append({"role": "assistant", "content": "⚠️ Chưa có tài liệu. Vui lòng upload trước."}) return history, "" chunks = retrieve(question) if not chunks: history.append({"role": "user", "content": question}) history.append({"role": "assistant", "content": "Không tìm thấy thông tin liên quan."}) return history, "" context = "\n\n---\n\n".join(f"[{c['source']}]\n{c['text']}" for c in chunks) system = ("Bạn là trợ lý AI. Trả lời câu hỏi DỰA TRÊN ngữ cảnh. " "Nếu không có trong ngữ cảnh, nói rõ. Trả lời bằng tiếng Việt.") msgs = [ {"role": "system", "content": system}, {"role": "user", "content": f"Ngữ cảnh:\n{context}\n\nCâu hỏi: {question}"}, ] prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) outputs = llm_pipeline(prompt, max_new_tokens=MAX_NEW_TOKENS, do_sample=False, return_full_text=False) answer = outputs[0]["generated_text"].strip() sources = list({c["source"] for c in chunks}) src_text = " \n📄 *Nguồn: " + " · ".join(sources) + "*" history.append({"role": "user", "content": question}) history.append({"role": "assistant", "content": answer + src_text}) return history, "" def reset_index(): global collection chroma_client.delete_collection(COLLECTION_NAME) collection = chroma_client.get_or_create_collection( name=COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) return "🗑️ Đã xóa toàn bộ.", 0 def get_count(): return collection.count() # ── CSS ─────────────────────────────────────────────────────────────────────── CSS = """ @import url('https://fonts.googleapis.com/css2?family=DM+Sans:ital,opsz,wght@0,9..40,300;0,9..40,400;0,9..40,500;0,9..40,600;1,9..40,400&family=DM+Mono:wght@400;500&display=swap'); *, *::before, *::after { box-sizing: border-box; } body, html { margin: 0; padding: 0; background: #0b0d12 !important; font-family: 'DM Sans', sans-serif !important; } footer, .built-with { display: none !important; } .gradio-container { max-width: 100% !important; padding: 0 !important; background: #0b0d12 !important; min-height: 100vh; } /* ── Sidebar ── */ .sidebar-content { background: #111318 !important; border-right: 1px solid #1c2030 !important; padding: 0 !important; } .brand-block { padding: 20px 18px 16px; border-bottom: 1px solid #1c2030; margin-bottom: 8px; } .brand-name { font-size: 14px; font-weight: 600; color: #f1f5f9; margin: 0; } .brand-sub { display: flex; align-items: center; gap: 6px; font-size: 11px; color: #4ade80; margin-top: 5px; } .dot { width: 6px; height: 6px; background: #4ade80; border-radius: 50%; animation: pulse 2s ease-in-out infinite; flex-shrink: 0; } @keyframes pulse { 0%,100%{opacity:1} 50%{opacity:.3} } .section-label { font-size: 9.5px; font-weight: 600; letter-spacing: .12em; text-transform: uppercase; color: #3d4a5c; padding: 14px 18px 6px; } /* File upload inside sidebar */ .sidebar-content .wrap { background: transparent !important; border: none !important; padding: 4px 14px !important; } .sidebar-content [data-testid="file-upload"], .sidebar-content .upload-container { background: #0b0e14 !important; border: 1.5px dashed #252d3d !important; border-radius: 10px !important; } /* Buttons */ .sidebar-content button.lg { background: linear-gradient(135deg,#2563eb,#1e40af) !important; border: none !important; border-radius: 8px !important; color: #fff !important; font-size: 13px !important; font-weight: 500 !important; padding: 10px !important; width: 100% !important; font-family: 'DM Sans', sans-serif !important; cursor: pointer !important; margin-top: 6px !important; } .sidebar-content button.stop { background: transparent !important; border: 1px solid #1c2030 !important; border-radius: 8px !important; color: #ef4444 !important; font-size: 12px !important; padding: 8px !important; width: 100% !important; font-family: 'DM Sans', sans-serif !important; cursor: pointer !important; margin-top: 4px !important; } .sidebar-content button.stop:hover { background: #1a0d0d !important; border-color: #ef4444 !important; } /* Number / label in sidebar */ .sidebar-content label span { color: #64748b !important; font-size: 11px !important; } .sidebar-content input[type=number] { background: #0b0e14 !important; border: 1px solid #1c2030 !important; border-radius: 8px !important; color: #60a5fa !important; font-family: 'DM Mono', monospace !important; font-size: 18px !important; font-weight: 500 !important; padding: 8px 12px !important; width: 100% !important; } .sidebar-content .prose p { font-size: 11.5px !important; color: #4b5563 !important; } /* ── Main area ── */ .main-header { display: flex; align-items: center; justify-content: space-between; padding: 14px 24px; border-bottom: 1px solid #1c2030; background: #0b0d12; } .main-header h2 { font-size: 15px; font-weight: 600; color: #f1f5f9; margin: 0; display: flex; align-items: center; gap: 8px; } .live-pill { background: #0c1f12; border: 1px solid #14532d; color: #4ade80; font-size: 10.5px; padding: 3px 10px; border-radius: 20px; display: flex; align-items: center; gap: 5px; } /* ── Chatbot ── */ .chatbot-wrap [data-testid="chatbot"] { background: transparent !important; border: none !important; } .chatbot-wrap .message-bubble-border { border-radius: 14px !important; } .chatbot-wrap .message.user > div { background: #172554 !important; border: 1px solid #1e3a8a55 !important; } .chatbot-wrap .message.bot > div, .chatbot-wrap .message.assistant > div { background: #111318 !important; border: 1px solid #1c2030 !important; } .chatbot-wrap .message p { color: #e2e8f0 !important; font-size: 14px !important; line-height: 1.6 !important; } .chatbot-wrap .message span { color: #94a3b8 !important; } /* ── Suggestion chips ── */ .sug-row { gap: 8px !important; padding: 10px 24px !important; flex-wrap: wrap; } .sug-row button { background: #111318 !important; border: 1px solid #1c2030 !important; border-radius: 10px !important; color: #64748b !important; font-size: 12px !important; padding: 7px 14px !important; font-family: 'DM Sans', sans-serif !important; transition: all .2s !important; white-space: nowrap !important; cursor: pointer !important; } .sug-row button:hover { background: #151c2e !important; border-color: #3b82f6 !important; color: #94a3b8 !important; } /* ── Input row ── */ .input-row { padding: 12px 24px !important; border-top: 1px solid #1c2030 !important; align-items: flex-end !important; } .input-row textarea { background: #111318 !important; border: 1px solid #1c2030 !important; border-radius: 12px !important; color: #e2e8f0 !important; font-size: 14px !important; padding: 12px 16px !important; resize: none !important; font-family: 'DM Sans', sans-serif !important; transition: border-color .2s !important; } .input-row textarea:focus { border-color: #3b82f6 !important; outline: none !important; } .input-row textarea::placeholder { color: #2d3748 !important; } .input-row button.primary { background: linear-gradient(135deg,#2563eb,#1e40af) !important; border: none !important; border-radius: 10px !important; color: #fff !important; font-size: 14px !important; font-weight: 600 !important; padding: 12px 22px !important; white-space: nowrap !important; height: 46px !important; font-family: 'DM Sans', sans-serif !important; cursor: pointer !important; } .input-row button.secondary { background: transparent !important; border: 1px solid #1c2030 !important; border-radius: 10px !important; color: #475569 !important; font-size: 13px !important; padding: 12px 14px !important; height: 46px !important; font-family: 'DM Sans', sans-serif !important; cursor: pointer !important; } """ SUGS = [ "💡 Tóm tắt nội dung tài liệu", "✨ Tìm thông tin mâu thuẫn", "📅 Trích xuất tất cả ngày & mốc", ] # ── UI ──────────────────────────────────────────────────────────────────────── with gr.Blocks(title="RAG · Qwen2.5-0.5B", theme=gr.themes.Base(), css=CSS) as demo: # ── Sidebar ─────────────────────────────────────────────────────────────── with gr.Sidebar(elem_classes="sidebar-content"): gr.HTML("""
RAG · Qwen2.5-0.5B
Local Model Active
Data Ingestion
""") file_input = gr.File( file_types=[".pdf", ".docx", ".txt", ".md"], show_label=False, ) upload_btn = gr.Button("⊕ Index Document", variant="primary") status_vis = gr.Markdown("", elem_classes="prose") chunk_vis = gr.Number(value=0, label="Chunks indexed", interactive=False) gr.HTML('
Actions
') clear_btn = gr.Button("🗑️ Clear All Documents", variant="stop") # ── Main content ────────────────────────────────────────────────────────── gr.HTML("""

💬 Chat Interface

Running
""") with gr.Column(elem_classes="chatbot-wrap"): chatbot = gr.Chatbot( height=460, show_label=False, render_markdown=True, type="messages", ) with gr.Row(elem_classes="sug-row"): s1 = gr.Button(SUGS[0]) s2 = gr.Button(SUGS[1]) s3 = gr.Button(SUGS[2]) with gr.Row(elem_classes="input-row"): question = gr.Textbox( placeholder="Đặt câu hỏi về tài liệu đã index...", show_label=False, lines=1, scale=5, ) ask_btn = gr.Button("➤ Ask", variant="primary", scale=1) clr_btn = gr.Button("✕", variant="secondary", scale=0) # ── Events ──────────────────────────────────────────────────────────────── upload_btn.click(fn=index_document, inputs=file_input, outputs=[status_vis, chunk_vis]) ask_btn.click(fn=answer_question, inputs=[question, chatbot], outputs=[chatbot, question]) question.submit(fn=answer_question, inputs=[question, chatbot], outputs=[chatbot, question]) clear_btn.click(fn=reset_index, outputs=[status_vis, chunk_vis]) clr_btn.click(lambda: [], outputs=chatbot) s1.click(fn=lambda h: answer_question(SUGS[0][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question]) s2.click(fn=lambda h: answer_question(SUGS[1][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question]) s3.click(fn=lambda h: answer_question(SUGS[2][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question]) demo.load(fn=get_count, outputs=chunk_vis) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", ssr_mode=False)