# app.py — "baocaomloncloud" Hugging Face Space (RAG demo).
# NOTE(review): the original upload carried HF web-page residue here
# ("Tiens0710's picture / Update app.py / 78066c1 verified"), replaced
# by this comment header so the file parses as Python.
import os, re, uuid
from pathlib import Path
# Route all Hugging Face caches (hub, transformers, datasets) to the
# persistent /data disk so downloaded models survive Space restarts.
# Must be set BEFORE the transformers / sentence_transformers imports below.
_CACHE_DIR = "/data/hf_cache"
os.makedirs(_CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = _CACHE_DIR
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — harmless to keep, verify on upgrade.
os.environ["TRANSFORMERS_CACHE"] = _CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = _CACHE_DIR
import gradio as gr
import chromadb
from chromadb.config import Settings
import fitz
from docx import Document as DocxDocument
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# E5-family embedder — its encode inputs are prefixed "query:" / "passage:"
# (see embed_query / embed_passages below).
EMBED_MODEL_ID = "intfloat/multilingual-e5-small"
# Small instruct LLM so the Space can run on CPU if no GPU is available.
LLM_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
COLLECTION_NAME = "rag_docs"
CHROMA_PATH = "/data/chromadb"  # persistent vector store location
# Chunk length (chars), chunk overlap (chars), retrieved chunks, generation budget (tokens).
CHUNK_SIZE, CHUNK_OVERLAP, TOP_K, MAX_NEW_TOKENS = 512, 64, 4, 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load the embedding model and the generator once at import time; both are
# reused by every request.
print(f"[INFO] Loading embedding: {EMBED_MODEL_ID}")
embed_model = SentenceTransformer(EMBED_MODEL_ID, device=DEVICE)
print(f"[INFO] Loading LLM: {LLM_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    # fp16 on GPU, fp32 on CPU (fp16 matmuls are unsupported/slow on CPU).
    # NOTE(review): `dtype=` is the newer spelling of `torch_dtype=` —
    # confirm the pinned transformers version accepts it.
    dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None,
)
llm_pipeline = pipeline(
    "text-generation", model=llm_model, tokenizer=tokenizer,
    # When device_map already placed the model (cuda), the pipeline must not
    # be given a device; only pass one for the plain CPU path.
    device=DEVICE if DEVICE == "cpu" else None,
)
# ── ChromaDB ──────────────────────────────────────────────────────────────────
def _make_chroma():
    """Build a Chroma client: persistent on /data if available, else in-memory."""
    if not os.path.exists("/data"):
        # No persistent disk attached (e.g. free Space) — ephemeral client.
        return chromadb.Client(Settings(anonymized_telemetry=False))
    os.makedirs(CHROMA_PATH, exist_ok=True)
    return chromadb.PersistentClient(path=CHROMA_PATH)
chroma_client = _make_chroma()
# "hnsw:space": "cosine" selects cosine distance for nearest-neighbour search.
collection = chroma_client.get_or_create_collection(
    name=COLLECTION_NAME, metadata={"hnsw:space": "cosine"})
# ── Document loaders ──────────────────────────────────────────────────────────
def load_pdf(p):
    """Extract the text of every page of a PDF, pages joined by blank lines.

    Args:
        p: Path to a .pdf file.

    Returns:
        The concatenated page text as one string.
    """
    # Context manager guarantees the document is closed even if get_text()
    # raises; the original only closed on the success path.
    with fitz.open(p) as doc:
        return "\n\n".join(page.get_text() for page in doc)
def load_docx(p):
    """Return the non-empty paragraphs of a .docx file, joined by blank lines."""
    paragraphs = DocxDocument(p).paragraphs
    kept = [para.text for para in paragraphs if para.text.strip()]
    return "\n\n".join(kept)
def load_text(p):
    """Read a .txt/.md file as UTF-8, silently dropping undecodable bytes.

    Args:
        p: Path to the text file.

    Returns:
        The file contents as a string.
    """
    # Path.read_text closes the handle deterministically; the original used a
    # bare open() and left closing to the garbage collector.
    return Path(p).read_text(encoding="utf-8", errors="ignore")
def load_document(p):
    """Dispatch to the loader matching the file extension.

    Raises:
        ValueError: if the extension is not .pdf, .docx, .txt or .md.
    """
    suffix = Path(p).suffix.lower()
    if suffix in (".txt", ".md"):
        return load_text(p)
    if suffix == ".docx":
        return load_docx(p)
    if suffix == ".pdf":
        return load_pdf(p)
    raise ValueError(f"Unsupported: {suffix}")
# ── Chunking ──────────────────────────────────────────────────────────────────
def split_text(text, chunk_size=None, overlap=None):
    """Split *text* into overlapping chunks, cutting on natural boundaries.

    Each chunk is at most *chunk_size* characters. The cut point is moved
    back to the latest paragraph, line, or sentence separator inside the
    window when one exists; consecutive chunks share *overlap* characters.
    Chunks of 30 characters or fewer (after stripping) are discarded.

    Args:
        text: Raw document text.
        chunk_size: Max chunk length in characters; defaults to the
            module-level CHUNK_SIZE (None sentinel keeps backward compat).
        overlap: Overlap between consecutive chunks; defaults to the
            module-level CHUNK_OVERLAP.

    Returns:
        List of non-trivial chunk strings.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if overlap is None:
        overlap = CHUNK_OVERLAP
    # Collapse runs of 3+ newlines so paragraph breaks are uniform.
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    chunks, start = [], 0
    while start < len(text):
        end = start + chunk_size
        if end >= len(text):
            chunks.append(text[start:])
            break
        sp = end
        # Prefer a paragraph break, then a line break, then a sentence end.
        # Searching from start + overlap guarantees forward progress.
        for sep in ["\n\n", "\n", ". "]:
            idx = text.rfind(sep, start + overlap, end)
            if idx != -1:
                sp = idx + len(sep)
                break
        chunks.append(text[start:sp])
        # Step back by *overlap* so adjacent chunks share context.
        start = sp - overlap
    return [c.strip() for c in chunks if len(c.strip()) > 30]
# ── Embeddings ────────────────────────────────────────────────────────────────
def embed_passages(texts):
    """Embed document chunks with the E5 "passage:" prefix; returns list-of-lists."""
    prefixed = ["passage: " + t for t in texts]
    vectors = embed_model.encode(prefixed, show_progress_bar=False)
    return vectors.tolist()
def embed_query(q):
    """Embed a single query with the E5 "query:" prefix; returns a list of one vector."""
    vector = embed_model.encode(["query: " + q])
    return vector.tolist()
# ── RAG core ──────────────────────────────────────────────────────────────────
def index_document(file_obj):
    """Load, chunk, embed and index an uploaded file into the collection.

    Args:
        file_obj: Gradio file wrapper exposing a .name path, or None.

    Returns:
        (status_markdown, total_chunks_in_collection) for the sidebar UI.
    """
    if file_obj is None:
        return "⚠️ Chưa chọn file.", collection.count()
    filename = Path(file_obj.name).name
    try:
        text = load_document(file_obj.name)
    except ValueError as e:
        return f"❌ {e}", collection.count()
    chunks = split_text(text)
    if not chunks:
        return "⚠️ Tài liệu rỗng.", collection.count()
    embs = embed_passages(chunks)
    # BUG FIX: ids and the status message previously used a literal
    # "(unknown)" placeholder instead of the uploaded file's name.
    ids = [f"{filename}_{i}_{uuid.uuid4().hex[:6]}" for i in range(len(chunks))]
    metas = [{"source": filename, "chunk": i} for i in range(len(chunks))]
    collection.add(documents=chunks, embeddings=embs, ids=ids, metadatas=metas)
    return f"✅ {filename} — {len(chunks)} chunks đã index.", collection.count()
def retrieve(query):
    """Return the TOP_K chunks most similar to *query*.

    Args:
        query: Natural-language question.

    Returns:
        List of {"text": chunk, "source": filename} dicts (empty if the
        collection has nothing indexed).
    """
    if collection.count() == 0:
        return []
    res = collection.query(
        query_embeddings=embed_query(query),
        n_results=min(TOP_K, collection.count()),
        # FIX: the original also requested "distances" and zipped them into
        # the comprehension, but never used the value.
        include=["documents", "metadatas"],
    )
    return [
        {"text": doc, "source": meta["source"]}
        for doc, meta in zip(res["documents"][0], res["metadatas"][0])
    ]
def answer_question(question, history):
    """Answer *question* via RAG over the indexed documents.

    Appends the user turn and the assistant turn to *history* (Gradio
    "messages" format: list of {"role", "content"} dicts) and returns
    (history, "") so the textbox is cleared.

    NOTE: the Vietnamese UI strings below were mojibake in the scraped
    original and have been restored to their intended text.
    """
    if not question.strip():
        return history, ""
    if collection.count() == 0:
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant",
                        "content": "⚠️ Chưa có tài liệu. Vui lòng upload trước."})
        return history, ""
    chunks = retrieve(question)
    if not chunks:
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant",
                        "content": "Không tìm thấy thông tin liên quan."})
        return history, ""
    # Build the grounding context: each chunk tagged with its source file.
    context = "\n\n---\n\n".join(f"[{c['source']}]\n{c['text']}" for c in chunks)
    system = ("Bạn là trợ lý AI. Trả lời câu hỏi DỰA TRÊN ngữ cảnh. "
              "Nếu không có trong ngữ cảnh, nói rõ. Trả lời bằng tiếng Việt.")
    msgs = [
        {"role": "system", "content": system},
        {"role": "user", "content": f"Ngữ cảnh:\n{context}\n\nCâu hỏi: {question}"},
    ]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # Greedy decoding; return_full_text=False strips the prompt from the output.
    outputs = llm_pipeline(prompt, max_new_tokens=MAX_NEW_TOKENS,
                           do_sample=False, return_full_text=False)
    answer = outputs[0]["generated_text"].strip()
    # Deduplicated source footer appended to the answer.
    sources = list({c["source"] for c in chunks})
    src_text = " \n📄 *Nguồn: " + " · ".join(sources) + "*"
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": answer + src_text})
    return history, ""
def reset_index():
    """Drop and recreate the collection, wiping every indexed document.

    Returns:
        (status_markdown, 0) for the sidebar status and chunk counter.
    """
    global collection
    chroma_client.delete_collection(COLLECTION_NAME)
    collection = chroma_client.get_or_create_collection(
        name=COLLECTION_NAME, metadata={"hnsw:space": "cosine"})
    # FIX: status string was mojibake in the scraped original.
    return "🗑️ Đã xóa toàn bộ.", 0
def get_count():
    """Return the number of indexed chunks (feeds the sidebar counter on load)."""
    return collection.count()
# ── CSS ───────────────────────────────────────────────────────────────────────
# Custom dark theme injected via gr.Blocks(css=CSS): hides the Gradio footer
# and styles the sidebar, chat area, suggestion chips, and input row.
# (String content is runtime behavior — kept byte-identical.)
CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Sans:ital,opsz,wght@0,9..40,300;0,9..40,400;0,9..40,500;0,9..40,600;1,9..40,400&family=DM+Mono:wght@400;500&display=swap');
*, *::before, *::after { box-sizing: border-box; }
body, html {
margin: 0; padding: 0;
background: #0b0d12 !important;
font-family: 'DM Sans', sans-serif !important;
}
footer, .built-with { display: none !important; }
.gradio-container {
max-width: 100% !important;
padding: 0 !important;
background: #0b0d12 !important;
min-height: 100vh;
}
/* ── Sidebar ── */
.sidebar-content {
background: #111318 !important;
border-right: 1px solid #1c2030 !important;
padding: 0 !important;
}
.brand-block {
padding: 20px 18px 16px;
border-bottom: 1px solid #1c2030;
margin-bottom: 8px;
}
.brand-name { font-size: 14px; font-weight: 600; color: #f1f5f9; margin: 0; }
.brand-sub {
display: flex; align-items: center; gap: 6px;
font-size: 11px; color: #4ade80; margin-top: 5px;
}
.dot {
width: 6px; height: 6px; background: #4ade80;
border-radius: 50%; animation: pulse 2s ease-in-out infinite; flex-shrink: 0;
}
@keyframes pulse { 0%,100%{opacity:1} 50%{opacity:.3} }
.section-label {
font-size: 9.5px; font-weight: 600;
letter-spacing: .12em; text-transform: uppercase;
color: #3d4a5c; padding: 14px 18px 6px;
}
/* File upload inside sidebar */
.sidebar-content .wrap { background: transparent !important; border: none !important; padding: 4px 14px !important; }
.sidebar-content [data-testid="file-upload"],
.sidebar-content .upload-container {
background: #0b0e14 !important;
border: 1.5px dashed #252d3d !important;
border-radius: 10px !important;
}
/* Buttons */
.sidebar-content button.lg {
background: linear-gradient(135deg,#2563eb,#1e40af) !important;
border: none !important; border-radius: 8px !important;
color: #fff !important; font-size: 13px !important; font-weight: 500 !important;
padding: 10px !important; width: 100% !important;
font-family: 'DM Sans', sans-serif !important; cursor: pointer !important;
margin-top: 6px !important;
}
.sidebar-content button.stop {
background: transparent !important; border: 1px solid #1c2030 !important;
border-radius: 8px !important; color: #ef4444 !important;
font-size: 12px !important; padding: 8px !important; width: 100% !important;
font-family: 'DM Sans', sans-serif !important; cursor: pointer !important;
margin-top: 4px !important;
}
.sidebar-content button.stop:hover { background: #1a0d0d !important; border-color: #ef4444 !important; }
/* Number / label in sidebar */
.sidebar-content label span { color: #64748b !important; font-size: 11px !important; }
.sidebar-content input[type=number] {
background: #0b0e14 !important; border: 1px solid #1c2030 !important;
border-radius: 8px !important; color: #60a5fa !important;
font-family: 'DM Mono', monospace !important; font-size: 18px !important;
font-weight: 500 !important; padding: 8px 12px !important; width: 100% !important;
}
.sidebar-content .prose p { font-size: 11.5px !important; color: #4b5563 !important; }
/* ── Main area ── */
.main-header {
display: flex; align-items: center; justify-content: space-between;
padding: 14px 24px; border-bottom: 1px solid #1c2030;
background: #0b0d12;
}
.main-header h2 {
font-size: 15px; font-weight: 600; color: #f1f5f9;
margin: 0; display: flex; align-items: center; gap: 8px;
}
.live-pill {
background: #0c1f12; border: 1px solid #14532d; color: #4ade80;
font-size: 10.5px; padding: 3px 10px; border-radius: 20px;
display: flex; align-items: center; gap: 5px;
}
/* ── Chatbot ── */
.chatbot-wrap [data-testid="chatbot"] {
background: transparent !important; border: none !important;
}
.chatbot-wrap .message-bubble-border { border-radius: 14px !important; }
.chatbot-wrap .message.user > div { background: #172554 !important; border: 1px solid #1e3a8a55 !important; }
.chatbot-wrap .message.bot > div,
.chatbot-wrap .message.assistant > div { background: #111318 !important; border: 1px solid #1c2030 !important; }
.chatbot-wrap .message p { color: #e2e8f0 !important; font-size: 14px !important; line-height: 1.6 !important; }
.chatbot-wrap .message span { color: #94a3b8 !important; }
/* ── Suggestion chips ── */
.sug-row { gap: 8px !important; padding: 10px 24px !important; flex-wrap: wrap; }
.sug-row button {
background: #111318 !important; border: 1px solid #1c2030 !important;
border-radius: 10px !important; color: #64748b !important;
font-size: 12px !important; padding: 7px 14px !important;
font-family: 'DM Sans', sans-serif !important; transition: all .2s !important;
white-space: nowrap !important; cursor: pointer !important;
}
.sug-row button:hover { background: #151c2e !important; border-color: #3b82f6 !important; color: #94a3b8 !important; }
/* ── Input row ── */
.input-row { padding: 12px 24px !important; border-top: 1px solid #1c2030 !important; align-items: flex-end !important; }
.input-row textarea {
background: #111318 !important; border: 1px solid #1c2030 !important;
border-radius: 12px !important; color: #e2e8f0 !important;
font-size: 14px !important; padding: 12px 16px !important;
resize: none !important; font-family: 'DM Sans', sans-serif !important;
transition: border-color .2s !important;
}
.input-row textarea:focus { border-color: #3b82f6 !important; outline: none !important; }
.input-row textarea::placeholder { color: #2d3748 !important; }
.input-row button.primary {
background: linear-gradient(135deg,#2563eb,#1e40af) !important;
border: none !important; border-radius: 10px !important; color: #fff !important;
font-size: 14px !important; font-weight: 600 !important; padding: 12px 22px !important;
white-space: nowrap !important; height: 46px !important;
font-family: 'DM Sans', sans-serif !important; cursor: pointer !important;
}
.input-row button.secondary {
background: transparent !important; border: 1px solid #1c2030 !important;
border-radius: 10px !important; color: #475569 !important;
font-size: 13px !important; padding: 12px 14px !important;
height: 46px !important; font-family: 'DM Sans', sans-serif !important; cursor: pointer !important;
}
"""
# Suggestion-chip labels. The click handlers strip the leading emoji + space
# with SUGS[i][2:], so each label must start with exactly one emoji codepoint
# and a space. FIX: the scraped original held mojibake (multi-char garbled
# emoji + garbled Vietnamese), which broke that 2-char strip.
SUGS = [
    "💡 Tóm tắt nội dung tài liệu",
    "✨ Tìm thông tin mâu thuẫn",
    "📅 Trích xuất tất cả ngày & mốc",
]
# ── UI ────────────────────────────────────────────────────────────────────────
# Build the Gradio app: sidebar for ingestion, main column for chat.
# FIX: several mojibake runtime strings restored ("·", "⊕", "🗑️", "✕",
# and the Vietnamese textbox placeholder).
with gr.Blocks(title="RAG · Qwen2.5-0.5B", theme=gr.themes.Base(), css=CSS) as demo:
    # ── Sidebar: upload, index status, reset ──────────────────────────────
    with gr.Sidebar(elem_classes="sidebar-content"):
        gr.HTML("""
        <div class="brand-block">
        <div class="brand-name">RAG · Qwen2.5-0.5B</div>
        <div class="brand-sub"><span class="dot"></span>Local Model Active</div>
        </div>
        <div class="section-label">Data Ingestion</div>
        """)
        file_input = gr.File(
            file_types=[".pdf", ".docx", ".txt", ".md"],
            show_label=False,
        )
        upload_btn = gr.Button("⊕ Index Document", variant="primary")
        status_vis = gr.Markdown("", elem_classes="prose")
        chunk_vis = gr.Number(value=0, label="Chunks indexed", interactive=False)
        gr.HTML('<div class="section-label">Actions</div>')
        clear_btn = gr.Button("🗑️ Clear All Documents", variant="stop")
    # ── Main content: header, chat, suggestion chips, input row ──────────
    gr.HTML("""
    <div class="main-header">
    <h2>💬 Chat Interface</h2>
    <div class="live-pill"><span class="dot" style="width:5px;height:5px"></span>Running</div>
    </div>
    """)
    with gr.Column(elem_classes="chatbot-wrap"):
        chatbot = gr.Chatbot(
            height=460,
            show_label=False,
            render_markdown=True,
            type="messages",  # history = list of {"role", "content"} dicts
        )
    with gr.Row(elem_classes="sug-row"):
        s1 = gr.Button(SUGS[0])
        s2 = gr.Button(SUGS[1])
        s3 = gr.Button(SUGS[2])
    with gr.Row(elem_classes="input-row"):
        question = gr.Textbox(
            placeholder="Đặt câu hỏi về tài liệu đã index...",
            show_label=False, lines=1, scale=5,
        )
        ask_btn = gr.Button("➀ Ask", variant="primary", scale=1)
        clr_btn = gr.Button("✕", variant="secondary", scale=0)
    # ── Events ────────────────────────────────────────────────────────────
    upload_btn.click(fn=index_document, inputs=file_input, outputs=[status_vis, chunk_vis])
    ask_btn.click(fn=answer_question, inputs=[question, chatbot], outputs=[chatbot, question])
    question.submit(fn=answer_question, inputs=[question, chatbot], outputs=[chatbot, question])
    clear_btn.click(fn=reset_index, outputs=[status_vis, chunk_vis])
    clr_btn.click(lambda: [], outputs=chatbot)
    # Suggestion chips send the label text minus its leading "emoji + space".
    s1.click(fn=lambda h: answer_question(SUGS[0][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question])
    s2.click(fn=lambda h: answer_question(SUGS[1][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question])
    s3.click(fn=lambda h: answer_question(SUGS[2][2:].strip(), h), inputs=chatbot, outputs=[chatbot, question])
    # Populate the chunk counter when the page first loads.
    demo.load(fn=get_count, outputs=chunk_vis)
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside the Spaces container.
    demo.launch(server_name="0.0.0.0", ssr_mode=False)