# Lab5 / app.py
# Uploaded by Denysyk (commit 85484cb: "Upload 2 files").
"""
Medical RAG System
"""
import gradio as gr
import numpy as np
import json
import re
from typing import Optional
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from groq import Groq
# 1. DATA LOADING
print("Завантажуємо датасет medmcqa...")
dataset = load_dataset("medmcqa", split="validation[:500]", trust_remote_code=True)

_OPTION_KEYS = ("a", "b", "c", "d")


def _item_to_doc(item) -> Optional[dict]:
    """Render one medmcqa record as a retrievable document.

    The document text holds the question, the four options, the correct
    answer and, when the ``exp`` field is non-empty, the explanation.

    Returns:
        A dict with "text", "subject", "question" and "answer". Returns ``None``
        when ``cop`` (index of the correct option) is not 0..3. Without this
        check, an unlabelled record (cop=-1) would pick option "d" through
        Python's negative indexing and store a wrong answer.
    """
    cop = item.get("cop")
    if cop not in range(len(_OPTION_KEYS)):
        return None
    # `or ""` also covers fields that are present but set to None.
    options = [item.get(f"op{k}") or "" for k in _OPTION_KEYS]
    correct_answer = options[int(cop)]
    question = item["question"]
    explanation = item.get("exp") or ""
    text = (
        f"Question: {question}\n"
        f"Options: A) {options[0]} B) {options[1]} C) {options[2]} D) {options[3]}\n"
        f"Answer: {correct_answer}\n"
    )
    if explanation:
        text += f"Explanation: {explanation}"
    return {
        "text": text,
        "subject": item.get("subject_name") or "",
        "question": question,
        "answer": correct_answer,
    }


# Build one document per labelled question (question + options + answer + explanation).
raw_docs = [doc for doc in map(_item_to_doc, dataset) if doc is not None]
print(f"Завантажено {len(raw_docs)} документів")
# 2. CHUNKING
def chunk_documents(docs, chunk_size=300, overlap=50):
    """Split documents into overlapping word windows.

    A document with at most ``chunk_size`` whitespace-separated words becomes
    a single chunk with its original text, line breaks included. A longer
    document is cut into windows of ``chunk_size`` words. Each window starts
    ``chunk_size - overlap`` words after the previous one, and the last window
    ends at the final word. Windowed chunks are re-joined with single spaces.

    Args:
        docs: sequence of dicts with "text", "subject", "question", "answer".
        chunk_size: maximum number of words per chunk; must be positive.
        overlap: number of words shared by consecutive windows. It must be in
            ``[0, chunk_size)``; otherwise the window would never advance
            (infinite loop) or would skip words.

    Returns:
        List of chunk dicts with keys "text", "source_id" (index of the source
        document in ``docs``), "subject", "question" and "answer".

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, "
            f"chunk_size={chunk_size}"
        )
    step = chunk_size - overlap
    chunks = []
    for idx, doc in enumerate(docs):
        text = doc["text"]
        words = text.split()
        meta = {
            "source_id": idx,
            "subject": doc["subject"],
            "question": doc["question"],
            "answer": doc["answer"],
        }
        if len(words) <= chunk_size:
            chunks.append({"text": text, **meta})
            continue
        for start in range(0, len(words), step):
            end = min(start + chunk_size, len(words))
            chunks.append({"text": " ".join(words[start:end]), **meta})
            if end == len(words):
                break
    return chunks
chunks = chunk_documents(raw_docs)
print(f"Отримано {len(chunks)} чанків після chunking")


# 3. BM25 RETRIEVER
_TOKEN_RE = re.compile(r"\w+")


def _tokenize(text: str) -> list:
    """Lower-case ``text`` and split it into word tokens, dropping punctuation.

    Plain whitespace splitting keeps punctuation glued to words. A query word
    like "aspirin?" would then never match "aspirin" or "aspirin," in the
    corpus, and hyphenated terms like "beta-blocker" would not match
    "beta blocker".
    """
    return _TOKEN_RE.findall(text.lower())


# The corpus and the queries must go through the same tokenizer.
tokenized_corpus = [_tokenize(c["text"]) for c in chunks]
bm25 = BM25Okapi(tokenized_corpus)


def bm25_search(query: str, top_k: int = 5):
    """Return up to ``top_k`` ``(chunk, score)`` pairs ranked by BM25 score.

    Chunks with a non-positive score are dropped, so fewer than ``top_k``
    results, or none, may be returned.
    """
    scores = bm25.get_scores(_tokenize(query))
    # int() so a float count (e.g. from a UI control) still slices the array.
    top_indices = np.argsort(scores)[::-1][: int(top_k)]
    return [(chunks[i], float(scores[i])) for i in top_indices if scores[i] > 0]
# 4. SEMANTIC (DENSE) RETRIEVER
# The model is loaded and every chunk is embedded once, at import time, so a
# query only has to embed the query string itself.
_EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

print("Завантажуємо модель для семантичного пошуку...")
embedder = SentenceTransformer(_EMBED_MODEL_NAME)

print("Обчислюємо ембедінги для всіх чанків...")
chunk_texts = [entry["text"] for entry in chunks]
# Row i of `chunk_embeddings` is used as the embedding of chunks[i].
chunk_embeddings = embedder.encode(
    chunk_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
)
print("Ембедінги готові!")
def semantic_search(query: str, top_k: int = 5):
    """Return the ``top_k`` chunks most similar to ``query`` as ``(chunk, score)``.

    The score is the cosine similarity between the query embedding and each
    precomputed chunk embedding. A zero norm product is replaced by a tiny
    epsilon to avoid division by zero. No score threshold is applied, so
    ``min(top_k, len(chunks))`` results are returned, best first.
    """
    query_emb = embedder.encode([query], convert_to_numpy=True)[0]
    norms = np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_emb)
    norms = np.where(norms == 0, 1e-9, norms)
    scores = chunk_embeddings @ query_emb / norms
    # int() so a float count (e.g. from a UI control) still slices the array.
    top_indices = np.argsort(scores)[::-1][: int(top_k)]
    return [(chunks[i], float(scores[i])) for i in top_indices]
# 5. HYBRID RETRIEVER
def hybrid_search(
    query: str,
    top_k: int = 5,
    use_bm25: bool = True,
    use_semantic: bool = True,
    rrf_k: float = 0,
):
    """Fuse BM25 and semantic rankings with reciprocal rank fusion (RRF).

    Each enabled retriever is asked for ``2 * top_k`` candidates. A chunk at
    0-based position ``rank`` in a retriever's list contributes
    ``1 / (rrf_k + rank + 1)`` to its fused score, and contributions from both
    retrievers are summed. The default ``rrf_k=0`` keeps the original
    scoring. Larger values (60 is common in the RRF literature) reduce the
    advantage of the very top positions.

    Returns:
        Up to ``top_k`` ``(chunk, fused_score)`` pairs, best first. Ties keep
        BM25-first insertion order. Returns ``[]`` when both retrievers are
        disabled.
    """
    if not use_bm25 and not use_semantic:
        return []
    top_k = int(top_k)
    rankings = []
    if use_bm25:
        rankings.append(bm25_search(query, top_k=top_k * 2))
    if use_semantic:
        rankings.append(semantic_search(query, top_k=top_k * 2))

    fused = {}
    for ranking in rankings:
        for rank, (chunk, _score) in enumerate(ranking):
            # Key on the full chunk text, not a prefix. Distinct chunks that
            # merely start alike (many questions open with the same phrase)
            # would otherwise merge, summing their scores and dropping one.
            entry = fused.setdefault(chunk["text"], {"chunk": chunk, "score": 0})
            entry["score"] += 1 / (rrf_k + rank + 1)
    ranked = sorted(fused.values(), key=lambda e: e["score"], reverse=True)
    return [(e["chunk"], e["score"]) for e in ranked[:top_k]]
# 6. LLM (GROQ)
def generate_answer(
    query: str,
    context_chunks: list,
    groq_api_key: str,
    model: str = "llama-3.3-70b-versatile",
) -> str:
    """Ask a Groq-hosted chat model to answer ``query`` from retrieved chunks.

    Chunks are numbered ``[1]..[n]`` in the given order so the model can cite
    them. The system prompt instructs it to answer only from that context.

    Args:
        query: the user's question.
        context_chunks: ``(chunk, score)`` pairs; only ``chunk["text"]`` is used.
        groq_api_key: API key passed to the Groq client.
        model: Groq model id to call.

    Returns:
        The model's reply text, or ``""`` if the response carried no content.
    """
    client = Groq(api_key=groq_api_key)
    context = "\n\n".join(
        f"[{i}] {chunk['text']}" for i, (chunk, _score) in enumerate(context_chunks, 1)
    )
    system_prompt = """You are a helpful medical assistant. Answer the user's question based ONLY on the provided context.
- Cite sources using square brackets like [1], [2] when you use information from them.
- If the context doesn't contain enough information, say so honestly.
- Be concise and accurate.
- Always mention relevant medical details from the context."""
    user_prompt = f"""Context:
{context}
Question: {query}
Answer (with citations like [1], [2]):"""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.2,
        max_tokens=600,
    )
    # The message content can be None; keep the `-> str` contract.
    return response.choices[0].message.content or ""
# 7. MAIN RAG FUNCTION
def _format_sources(retrieved):
    """Render retrieved ``(chunk, score)`` pairs as a Markdown list of sources.

    Each chunk's text is shown in a fenced block, truncated to 400 characters.
    """
    parts = ["### Джерела (чанки)\n\n"]
    for i, (chunk, score) in enumerate(retrieved, 1):
        subj = chunk.get("subject", "—")
        text = chunk["text"]
        preview = text[:400] + ("..." if len(text) > 400 else "")
        parts.append(f"**[{i}]** *(Subject: {subj}, Score: {score:.4f})*\n\n")
        parts.append(f"```\n{preview}\n```\n\n")
    return "".join(parts)


def rag_query(
    query: str,
    groq_api_key: str,
    use_bm25: bool,
    use_semantic: bool,
    top_k: int,
):
    """Run the full RAG pipeline for the Gradio UI.

    Validates the inputs, retrieves chunks with ``hybrid_search``, generates
    an answer with ``generate_answer`` and formats the retrieved sources.

    Returns:
        A 3-tuple ``(answer_markdown, sources_markdown, info_markdown)``. On
        invalid input or any failure, the first element carries a message and
        the other two are empty strings.
    """
    # Tolerate None and surrounding whitespace. A key pasted with a trailing
    # newline must be sent stripped, not just validated stripped.
    query = (query or "").strip()
    groq_api_key = (groq_api_key or "").strip()
    if not query:
        return "⚠️ Введіть запитання.", "", ""
    if not groq_api_key:
        return "⚠️ Введіть Groq API ключ.", "", ""
    if not use_bm25 and not use_semantic:
        return "⚠️ Увімкніть хоча б один метод пошуку.", "", ""
    try:
        # Slicing needs an int; the count may arrive as a float from a UI
        # slider (TODO confirm for the Gradio version in use).
        top_k = max(1, int(top_k))
        # Retrieve relevant chunks.
        retrieved = hybrid_search(query, top_k=top_k, use_bm25=use_bm25, use_semantic=use_semantic)
        if not retrieved:
            return "Не знайдено релевантних документів.", "", ""
        # Generate the answer.
        answer = generate_answer(query, retrieved, groq_api_key)
        sources_md = _format_sources(retrieved)
        if use_bm25 and use_semantic:
            method = "Hybrid (BM25 + Semantic)"
        elif use_bm25:
            method = "BM25 (keyword)"
        else:
            method = "Semantic (dense)"
        info = f"**Метод пошуку:** {method} | **Знайдено чанків:** {len(retrieved)}"
        return answer, sources_md, info
    except Exception as e:
        # UI boundary: show retrieval/LLM failures in the answer pane instead
        # of crashing the Gradio callback.
        return f"Помилка: {str(e)}", "", ""
# ──────────────────────────────────────────────
# 8. GRADIO UI
# ──────────────────────────────────────────────
with gr.Blocks(title="Medical RAG System", theme=gr.themes.Soft()) as demo:
    # A Markdown ATX heading needs a space after "#". "#Medical" is rendered
    # as plain text, not as a title.
    gr.Markdown("""
# Medical RAG System
**Retrieval-Augmented Generation** для відповідей на медичні запитання.
Система використовує датасет **medmcqa** (медичні питання з іспитів) та поєднує
BM25 (пошук по ключових словах) і семантичний пошук для знаходження релевантних джерел.
""")
    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Ваше запитання",
                placeholder="Напр.: What is the mechanism of action of aspirin?",
                lines=2,
            )
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
            )
        with gr.Column(scale=1):
            use_bm25 = gr.Checkbox(label="BM25 (keyword search)", value=True)
            use_semantic = gr.Checkbox(label="Semantic search", value=True)
            top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Кількість чанків (top-k)")
    submit_btn = gr.Button("🔍 Знайти відповідь", variant="primary")
    info_box = gr.Markdown("")
    with gr.Tabs():
        with gr.Tab("Відповідь"):
            answer_output = gr.Markdown(label="Відповідь")
        with gr.Tab("Джерела"):
            sources_output = gr.Markdown(label="Використані чанки")
    # Examples only pre-fill the inputs; the API key is still entered manually.
    gr.Examples(
        examples=[
            ["What is the mechanism of action of aspirin?", True, True, 5],
            ["Which vitamin deficiency causes night blindness?", True, True, 5],
            ["What are symptoms of diabetes mellitus?", False, True, 5],
            ["beta blocker mechanism", True, False, 5],
        ],
        inputs=[query_input, use_bm25, use_semantic, top_k],
        label="Приклади запитів",
    )
    # Output order matches rag_query's (answer, sources, info) return tuple.
    submit_btn.click(
        fn=rag_query,
        inputs=[query_input, api_key_input, use_bm25, use_semantic, top_k],
        outputs=[answer_output, sources_output, info_box],
    )
    # Footer: the LLM named here must match the model generate_answer calls.
    gr.Markdown("""
---
**Датасет:** [medmcqa](https://huggingface.co/datasets/medmcqa) |
**LLM:** Groq (llama-3.3-70b-versatile) |
**Embeddings:** all-MiniLM-L6-v2 |
**Chunking:** sliding window (300 слів, overlap 50)
""")


if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)