| """ |
| Medical RAG System |
| """ |
|
|
| import gradio as gr |
| import numpy as np |
| import json |
| import re |
| from typing import Optional |
| from datasets import load_dataset |
| from rank_bm25 import BM25Okapi |
| from sentence_transformers import SentenceTransformer |
| from groq import Groq |
|
|
| |
|
|
print("Завантажуємо датасет medmcqa...")
# NOTE(review): trust_remote_code=True allows `datasets` to execute a loading
# script fetched from the Hub; keep it only if this dataset still requires one.
dataset = load_dataset("medmcqa", split="validation[:500]", trust_remote_code=True)

# medmcqa option columns are "opa".."opd"; the integer `cop` field indexes this tuple.
_OPTION_KEYS = ("a", "b", "c", "d")


def _build_document(item: dict) -> dict:
    """Render one medmcqa row as a plain-text document for retrieval.

    The text lists the question and the four options, then the correct answer
    (only when the row carries a valid ``cop`` label) and the explanation (only
    when present).

    Args:
        item: one dataset row with "question", "opa".."opd", "cop", and
            optionally "exp" and "subject_name".

    Returns:
        A dict with "text", "subject", "question" and "answer" keys. "answer" is
        "" when ``cop`` is missing or out of range.
    """
    question = item["question"]
    explanation = item.get("exp") or ""
    subject = item.get("subject_name") or ""
    # `or ""` so a null option renders as empty text instead of "None".
    options = [item.get(f"op{key}") or "" for key in _OPTION_KEYS]

    cop = item.get("cop")
    # Only trust in-range labels: an unlabeled/out-of-range value such as -1
    # would otherwise wrap around via negative indexing and mark option D correct.
    has_label = isinstance(cop, int) and 0 <= cop < len(_OPTION_KEYS)
    correct_answer = options[cop] if has_label else ""

    lines = [
        f"Question: {question}",
        f"Options: A) {options[0]} B) {options[1]} C) {options[2]} D) {options[3]}",
    ]
    if has_label:
        lines.append(f"Answer: {correct_answer}")
    text = "\n".join(lines) + "\n"
    if explanation:
        text += f"Explanation: {explanation}"

    return {
        "text": text,
        "subject": subject,
        "question": question,
        "answer": correct_answer,
    }


raw_docs = [_build_document(item) for item in dataset]

print(f"Завантажено {len(raw_docs)} документів")
|
|
| |
|
|
def _sliding_windows(text, chunk_size, step):
    """Return *text* whole if it has at most *chunk_size* words, else space-joined word windows."""
    words = text.split()
    if len(words) <= chunk_size:
        # Short documents keep their original formatting (newlines included).
        return [text]
    pieces = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        pieces.append(" ".join(words[start:end]))
        if end >= len(words):
            # This window already reached the last word; another would only repeat its tail.
            break
    return pieces


def chunk_documents(docs, chunk_size=300, overlap=50):
    """Split documents into word-based sliding-window chunks.

    Documents with at most ``chunk_size`` whitespace-separated words are kept
    whole. Longer documents are cut into windows of ``chunk_size`` words that
    start every ``chunk_size - overlap`` words; those windows are re-joined
    with single spaces.

    Args:
        docs: sequence of dicts with "text", "subject", "question" and "answer" keys.
        chunk_size: maximum number of words per chunk; must be positive.
        overlap: number of words shared by consecutive windows; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunk dicts with "text", "source_id" (index of the source
        document in ``docs``), "subject", "question" and "answer".

    Raises:
        ValueError: if the parameters would make the window stop advancing
            (previously an infinite loop) or skip words (negative overlap).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )
    step = chunk_size - overlap

    chunks = []
    for idx, doc in enumerate(docs):
        chunks.extend(
            {
                "text": piece,
                "source_id": idx,
                "subject": doc["subject"],
                "question": doc["question"],
                "answer": doc["answer"],
            }
            for piece in _sliding_windows(doc["text"], chunk_size, step)
        )
    return chunks
|
|
chunks = chunk_documents(raw_docs)
print(f"Отримано {len(chunks)} чанків після chunking")

# Tokens are lowercase runs of word characters. Plain str.split() leaves
# punctuation attached, so a query word like "aspirin?" could never match the
# corpus token "aspirin".
_TOKEN_PATTERN = re.compile(r"\w+")


def _tokenize(text: str) -> list:
    """Lowercase *text* and return its word tokens; used for both the index and queries."""
    return _TOKEN_PATTERN.findall(text.lower())


tokenized_corpus = [_tokenize(c["text"]) for c in chunks]
bm25 = BM25Okapi(tokenized_corpus)


def bm25_search(query: str, top_k: int = 5):
    """Return up to *top_k* ``(chunk, bm25_score)`` pairs, best first.

    The query is tokenized exactly like the indexed chunks. Chunks whose
    score is not positive are dropped, so fewer than *top_k* pairs may be
    returned (none for a query with no word tokens).
    """
    scores = bm25.get_scores(_tokenize(query))
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(chunks[i], float(scores[i])) for i in top_indices if scores[i] > 0]
|
|
| |
|
|
print("Завантажуємо модель для семантичного пошуку...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

print("Обчислюємо ембедінги для всіх чанків...")
chunk_texts = [chunk["text"] for chunk in chunks]
chunk_embeddings = embedder.encode(
    chunk_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
)
print("Ембедінги готові!")


def semantic_search(query: str, top_k: int = 5):
    """Rank all chunks by cosine similarity to the embedding of *query*.

    Returns:
        The *top_k* ``(chunk, cosine_similarity)`` pairs, most similar first.
        No minimum-score filter is applied.
    """
    query_vector = embedder.encode([query], convert_to_numpy=True)[0]
    # Denominator of the cosine: product of vector lengths. A zero length is
    # replaced by a tiny value so the division is always defined.
    length_products = np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_vector)
    safe_lengths = np.where(length_products == 0, 1e-9, length_products)
    similarities = (chunk_embeddings @ query_vector) / safe_lengths
    best_first = np.argsort(similarities)[::-1]
    return [(chunks[idx], float(similarities[idx])) for idx in best_first[:top_k]]
|
|
| |
|
|
| def hybrid_search(query: str, top_k: int = 5, use_bm25: bool = True, use_semantic: bool = True): |
| if not use_bm25 and not use_semantic: |
| return [] |
|
|
| results = {} |
|
|
| if use_bm25: |
| bm25_results = bm25_search(query, top_k=top_k * 2) |
| for rank, (chunk, score) in enumerate(bm25_results): |
| key = chunk["text"][:80] |
| results[key] = results.get(key, {"chunk": chunk, "score": 0}) |
| results[key]["score"] += 1 / (rank + 1) |
|
|
| if use_semantic: |
| sem_results = semantic_search(query, top_k=top_k * 2) |
| for rank, (chunk, score) in enumerate(sem_results): |
| key = chunk["text"][:80] |
| results[key] = results.get(key, {"chunk": chunk, "score": 0}) |
| results[key]["score"] += 1 / (rank + 1) |
|
|
| sorted_results = sorted(results.values(), key=lambda x: x["score"], reverse=True) |
| return [(r["chunk"], r["score"]) for r in sorted_results[:top_k]] |
|
|
| |
|
|
| def generate_answer(query: str, context_chunks: list, groq_api_key: str) -> str: |
| client = Groq(api_key=groq_api_key) |
|
|
| context_parts = [] |
| for i, (chunk, score) in enumerate(context_chunks, 1): |
| context_parts.append(f"[{i}] {chunk['text']}") |
| context = "\n\n".join(context_parts) |
|
|
| system_prompt = """You are a helpful medical assistant. Answer the user's question based ONLY on the provided context. |
| - Cite sources using square brackets like [1], [2] when you use information from them. |
| - If the context doesn't contain enough information, say so honestly. |
| - Be concise and accurate. |
| - Always mention relevant medical details from the context.""" |
|
|
| user_prompt = f"""Context: |
| {context} |
| |
| Question: {query} |
| |
| Answer (with citations like [1], [2]):""" |
|
|
| response = client.chat.completions.create( |
| model="llama-3.3-70b-versatile", |
| messages=[ |
| {"role": "system", "content": system_prompt}, |
| {"role": "user", "content": user_prompt}, |
| ], |
| temperature=0.2, |
| max_tokens=600, |
| ) |
| return response.choices[0].message.content |
|
|
| |
|
|
| def rag_query( |
| query: str, |
| groq_api_key: str, |
| use_bm25: bool, |
| use_semantic: bool, |
| top_k: int, |
| ): |
| if not query.strip(): |
| return "⚠️ Введіть запитання.", "", "" |
| if not groq_api_key.strip(): |
| return "⚠️ Введіть Groq API ключ.", "", "" |
| if not use_bm25 and not use_semantic: |
| return "⚠️ Увімкніть хоча б один метод пошуку.", "", "" |
|
|
| try: |
| |
| retrieved = hybrid_search(query, top_k=top_k, use_bm25=use_bm25, use_semantic=use_semantic) |
|
|
| if not retrieved: |
| return "Не знайдено релевантних документів.", "", "" |
|
|
| |
| answer = generate_answer(query, retrieved, groq_api_key) |
|
|
| sources_md = "### Джерела (чанки)\n\n" |
| for i, (chunk, score) in enumerate(retrieved, 1): |
| subj = chunk.get("subject", "—") |
| sources_md += f"**[{i}]** *(Subject: {subj}, Score: {score:.4f})*\n\n" |
| sources_md += f"```\n{chunk['text'][:400]}{'...' if len(chunk['text']) > 400 else ''}\n```\n\n" |
|
|
| methods = [] |
| if use_bm25 and use_semantic: |
| methods.append("Hybrid (BM25 + Semantic)") |
| elif use_bm25: |
| methods.append("BM25 (keyword)") |
| else: |
| methods.append("Semantic (dense)") |
|
|
| info = f"**Метод пошуку:** {', '.join(methods)} | **Знайдено чанків:** {len(retrieved)}" |
|
|
| return answer, sources_md, info |
|
|
| except Exception as e: |
| return f"Помилка: {str(e)}", "", "" |
|
|
| |
| |
| |
|
|
with gr.Blocks(title="Medical RAG System", theme=gr.themes.Soft()) as demo:
    # The heading needs a space after "#" to render as a Markdown heading.
    gr.Markdown("""
# Medical RAG System
**Retrieval-Augmented Generation** для відповідей на медичні запитання.

Система використовує датасет **medmcqa** (медичні питання з іспитів) та поєднує
BM25 (пошук по ключових словах) і семантичний пошук для знаходження релевантних джерел.
""")

    with gr.Row():
        # Left column: the question and the user's own Groq key (masked input).
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Ваше запитання",
                placeholder="Напр.: What is the mechanism of action of aspirin?",
                lines=2,
            )
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
            )

        # Right column: retrieval switches, top-k, and the submit button.
        with gr.Column(scale=1):
            use_bm25 = gr.Checkbox(label="BM25 (keyword search)", value=True)
            use_semantic = gr.Checkbox(label="Semantic search", value=True)
            top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Кількість чанків (top-k)")
            submit_btn = gr.Button("🔍 Знайти відповідь", variant="primary")

    info_box = gr.Markdown("")

    with gr.Tabs():
        with gr.Tab("Відповідь"):
            answer_output = gr.Markdown(label="Відповідь")
        with gr.Tab("Джерела"):
            sources_output = gr.Markdown(label="Використані чанки")

    # Examples fill every input except the API key, which the user must supply.
    gr.Examples(
        examples=[
            ["What is the mechanism of action of aspirin?", True, True, 5],
            ["Which vitamin deficiency causes night blindness?", True, True, 5],
            ["What are symptoms of diabetes mellitus?", False, True, 5],
            ["beta blocker mechanism", True, False, 5],
        ],
        inputs=[query_input, use_bm25, use_semantic, top_k],
        label="Приклади запитів",
    )

    # Input order must match rag_query's parameters; outputs match its 3-tuple.
    submit_btn.click(
        fn=rag_query,
        inputs=[query_input, api_key_input, use_bm25, use_semantic, top_k],
        outputs=[answer_output, sources_output, info_box],
    )

    # Footer: keep the LLM name in sync with the model actually requested.
    gr.Markdown("""
---
**Датасет:** [medmcqa](https://huggingface.co/datasets/medmcqa) |
**LLM:** Groq (llama-3.3-70b-versatile) |
**Embeddings:** all-MiniLM-L6-v2 |
**Chunking:** sliding window (300 слів, overlap 50)
""")


if __name__ == "__main__":
    # share=True asks Gradio for a temporary public URL in addition to the local server.
    demo.launch(share=True)
|
|