"""Medical RAG System.

Retrieval-augmented question answering over the medmcqa dataset: BM25 and
dense (sentence-embedding) retrieval are fused, and a Groq-hosted LLM writes
the final answer with bracketed citations.
"""
import json
import re
from typing import Optional

import gradio as gr
import numpy as np
from datasets import load_dataset
from groq import Groq
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

# ──────────────────────────────────────────────
# 1. DATA LOADING
# ──────────────────────────────────────────────
# Option columns of a medmcqa record (opa, opb, opc, opd), in label order.
OPTION_KEYS = ["a", "b", "c", "d"]

print("Завантажуємо датасет medmcqa...")
dataset = load_dataset("medmcqa", split="validation[:500]", trust_remote_code=True)


def _build_document(item: dict) -> dict:
    """Turn one medmcqa record into a retrievable document.

    The document text contains the question, all four options, the correct
    answer and, when present, the explanation.

    Args:
        item: a dataset record. It is read via the keys ``question``,
            ``opa``..``opd``, ``cop``, ``exp`` and ``subject_name``.

    Returns:
        A dict with ``text``, ``subject``, ``question`` and ``answer`` keys.
    """
    question = item["question"]
    # `or ""` so a missing/None option renders as empty instead of "None".
    options = [item.get(f"op{k}") or "" for k in OPTION_KEYS]
    explanation = item.get("exp") or ""

    # `cop` is the index of the correct option. Validate it explicitly: an
    # out-of-range label would raise, and -1 would silently select option D
    # through negative indexing.
    cop = item.get("cop")
    if isinstance(cop, int) and 0 <= cop < len(options):
        correct_answer = options[cop]
    else:
        correct_answer = ""

    text = f"Question: {question}\n"
    text += f"Options: A) {options[0]} B) {options[1]} C) {options[2]} D) {options[3]}\n"
    text += f"Answer: {correct_answer}\n"
    if explanation:
        text += f"Explanation: {explanation}"

    return {
        "text": text,
        "subject": item.get("subject_name") or "",
        "question": question,
        "answer": correct_answer,
    }


# Build one document per record from its question, options, answer and explanation.
raw_docs = [_build_document(item) for item in dataset]
print(f"Завантажено {len(raw_docs)} документів")

# ──────────────────────────────────────────────
# 2. CHUNKING
# ──────────────────────────────────────────────


def chunk_documents(docs, chunk_size=300, overlap=50):
    """Split documents into word-window chunks that carry their source metadata.

    A document of at most ``chunk_size`` words is kept whole, with its original
    whitespace preserved. A longer one is cut into windows of ``chunk_size``
    words (re-joined with single spaces); consecutive windows share
    ``overlap`` words.

    Args:
        docs: dicts with ``text``, ``subject``, ``question`` and ``answer`` keys.
        chunk_size: maximum number of whitespace-separated words per chunk.
        overlap: number of words repeated between consecutive windows.

    Returns:
        A list of chunk dicts with keys ``text``, ``source_id`` (the index of
        the source document in ``docs``), ``subject``, ``question`` and ``answer``.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` is outside
            ``[0, chunk_size)``. Either would stop the window from advancing
            and loop forever.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")
    step = chunk_size - overlap

    chunks = []
    for idx, doc in enumerate(docs):
        meta = {
            "source_id": idx,
            "subject": doc["subject"],
            "question": doc["question"],
            "answer": doc["answer"],
        }
        words = doc["text"].split()
        if len(words) <= chunk_size:
            chunks.append({"text": doc["text"], **meta})
            continue
        for start in range(0, len(words), step):
            end = min(start + chunk_size, len(words))
            chunks.append({"text": " ".join(words[start:end]), **meta})
            # Stop once a window reaches the end; otherwise the next start
            # would emit a trailing window that is a pure subset of this one.
            if end == len(words):
                break
    return chunks


chunks = chunk_documents(raw_docs)
print(f"Отримано {len(chunks)} чанків після chunking")
# ──────────────────────────────────────────────
# 3. BM25 RETRIEVER
# ──────────────────────────────────────────────
# Word tokens; punctuation is dropped so "aspirin?" and "aspirin" match.
_TOKEN_RE = re.compile(r"\w+")


def _tokenize(text: str) -> list:
    """Lowercase *text* and return its word tokens without punctuation.

    The same tokenizer must be used for the corpus and for queries. A plain
    ``split()`` keeps punctuation attached ("aspirin?" vs "aspirin"), so the
    last word of a question could never match the corpus.
    """
    return _TOKEN_RE.findall(text.lower())


tokenized_corpus = [_tokenize(c["text"]) for c in chunks]
bm25 = BM25Okapi(tokenized_corpus)


def bm25_search(query: str, top_k: int = 5):
    """Return up to ``top_k`` ``(chunk, bm25_score)`` pairs, best first.

    Chunks with a non-positive score (no useful term overlap) are dropped, so
    fewer than ``top_k`` results may be returned.
    """
    tokenized_query = _tokenize(query)
    if not tokenized_query:
        return []
    scores = bm25.get_scores(tokenized_query)
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(chunks[i], float(scores[i])) for i in top_indices if scores[i] > 0]


# ──────────────────────────────────────────────
# 4. SEMANTIC (DENSE) RETRIEVER
# ──────────────────────────────────────────────
print("Завантажуємо модель для семантичного пошуку...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("Обчислюємо ембедінги для всіх чанків...")
chunk_texts = [c["text"] for c in chunks]
chunk_embeddings = embedder.encode(
    chunk_texts, batch_size=64, show_progress_bar=True, convert_to_numpy=True
)
# The chunk embeddings are fixed, so compute their norms once instead of on every query.
chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
print("Ембедінги готові!")


def semantic_search(query: str, top_k: int = 5):
    """Return the ``top_k`` ``(chunk, cosine_similarity)`` pairs, best first."""
    query_emb = embedder.encode([query], convert_to_numpy=True)[0]
    norms = chunk_norms * np.linalg.norm(query_emb)
    # Guard against division by zero for degenerate (all-zero) vectors.
    norms = np.where(norms == 0, 1e-9, norms)
    scores = chunk_embeddings @ query_emb / norms
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(chunks[i], float(scores[i])) for i in top_indices]


# ──────────────────────────────────────────────
# 5. HYBRID RETRIEVER
# ──────────────────────────────────────────────


def hybrid_search(query: str, top_k: int = 5, use_bm25: bool = True, use_semantic: bool = True):
    """Fuse the enabled retrievers with reciprocal rank fusion.

    Each enabled retriever returns ``2 * top_k`` candidates. A candidate at
    0-based position ``rank`` in a list gains ``1 / (rank + 1)``, and gains
    are summed across retrievers.

    Returns:
        Up to ``top_k`` ``(chunk, fused_score)`` pairs, highest score first.
        Returns an empty list when both retrievers are disabled.
    """
    if not use_bm25 and not use_semantic:
        return []

    rankings = []
    if use_bm25:
        rankings.append(bm25_search(query, top_k=top_k * 2))
    if use_semantic:
        rankings.append(semantic_search(query, top_k=top_k * 2))

    # Key on chunk identity: both retrievers return the same dict objects from
    # `chunks`. Keying on a text prefix would merge distinct chunks whose
    # questions open the same way ("Question: Which of the following ...").
    fused = {}
    for ranking in rankings:
        for rank, (chunk, _score) in enumerate(ranking):
            entry = fused.setdefault(id(chunk), {"chunk": chunk, "score": 0})
            entry["score"] += 1 / (rank + 1)  # reciprocal rank fusion

    ordered = sorted(fused.values(), key=lambda e: e["score"], reverse=True)
    return [(e["chunk"], e["score"]) for e in ordered[:top_k]]
# ──────────────────────────────────────────────
# 6. LLM (GROQ)
# ──────────────────────────────────────────────
# Single source of truth for the model: used by the API call and shown in the UI footer.
LLM_MODEL = "llama-3.3-70b-versatile"


def generate_answer(query: str, context_chunks: list, groq_api_key: str) -> str:
    """Ask the Groq chat model to answer *query* using only *context_chunks*.

    Args:
        query: the user's question.
        context_chunks: ``(chunk, score)`` pairs. Chunks are numbered [1], [2], ...
            in this order, so the model's citations line up with the sources tab.
        groq_api_key: Groq API key; surrounding whitespace is ignored.

    Returns:
        The model's answer text, or ``""`` if the response carries no content.
    """
    client = Groq(api_key=groq_api_key.strip())

    context = "\n\n".join(
        f"[{i}] {chunk['text']}" for i, (chunk, _score) in enumerate(context_chunks, 1)
    )

    system_prompt = """You are a helpful medical assistant. Answer the user's question based ONLY on the provided context.
- Cite sources using square brackets like [1], [2] when you use information from them.
- If the context doesn't contain enough information, say so honestly.
- Be concise and accurate.
- Always mention relevant medical details from the context."""

    user_prompt = f"""Context:
{context}

Question: {query}

Answer (with citations like [1], [2]):"""

    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.2,
        max_tokens=600,
    )
    # `content` may be None (e.g. an empty completion); the UI expects a string.
    return response.choices[0].message.content or ""


# ──────────────────────────────────────────────
# 7. MAIN RAG FUNCTION
# ──────────────────────────────────────────────


def _format_sources(retrieved, preview_chars: int = 400) -> str:
    """Render the retrieved ``(chunk, score)`` pairs as a Markdown source list."""
    parts = ["### Джерела (чанки)\n\n"]
    for i, (chunk, score) in enumerate(retrieved, 1):
        subj = chunk.get("subject") or "—"
        text = chunk["text"]
        preview = text[:preview_chars] + ("..." if len(text) > preview_chars else "")
        parts.append(f"**[{i}]** *(Subject: {subj}, Score: {score:.4f})*\n\n")
        parts.append(f"```\n{preview}\n```\n\n")
    return "".join(parts)


def _method_label(use_bm25: bool, use_semantic: bool) -> str:
    """Return a human-readable name for the enabled retrieval method(s)."""
    if use_bm25 and use_semantic:
        return "Hybrid (BM25 + Semantic)"
    if use_bm25:
        return "BM25 (keyword)"
    return "Semantic (dense)"


def rag_query(
    query: str,
    groq_api_key: str,
    use_bm25: bool,
    use_semantic: bool,
    top_k: int,
):
    """Run retrieval and generation for the Gradio UI.

    Returns:
        ``(answer_markdown, sources_markdown, info_markdown)``. On invalid
        input or on any pipeline error, the first element holds a message
        and the other two are empty.
    """
    # Textboxes may hand over None; treat that the same as empty input.
    if not query or not query.strip():
        return "⚠️ Введіть запитання.", "", ""
    api_key = (groq_api_key or "").strip()
    if not api_key:
        return "⚠️ Введіть Groq API ключ.", "", ""
    if not use_bm25 and not use_semantic:
        return "⚠️ Увімкніть хоча б один метод пошуку.", "", ""

    try:
        # Sliders can deliver floats; slicing in the retrievers needs an int.
        top_k = int(top_k)

        # Retrieve relevant chunks.
        retrieved = hybrid_search(query, top_k=top_k, use_bm25=use_bm25, use_semantic=use_semantic)
        if not retrieved:
            return "Не знайдено релевантних документів.", "", ""

        # Generate the answer.
        answer = generate_answer(query, retrieved, api_key)

        sources_md = _format_sources(retrieved)
        info = (
            f"**Метод пошуку:** {_method_label(use_bm25, use_semantic)}"
            f" | **Знайдено чанків:** {len(retrieved)}"
        )
        return answer, sources_md, info
    except Exception as e:
        # UI boundary: show the error to the user instead of crashing the app.
        return f"Помилка: {str(e)}", "", ""


# ──────────────────────────────────────────────
# 8. GRADIO UI
# ──────────────────────────────────────────────
with gr.Blocks(title="Medical RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Medical RAG System
    **Retrieval-Augmented Generation** для відповідей на медичні запитання.
    Система використовує датасет **medmcqa** (медичні питання з іспитів) та поєднує
    BM25 (пошук по ключових словах) і семантичний пошук для знаходження релевантних джерел.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Ваше запитання",
                placeholder="Напр.: What is the mechanism of action of aspirin?",
                lines=2,
            )
            api_key_input = gr.Textbox(
                label="Groq API Key",
                placeholder="gsk_...",
                type="password",
            )
        with gr.Column(scale=1):
            use_bm25 = gr.Checkbox(label="BM25 (keyword search)", value=True)
            use_semantic = gr.Checkbox(label="Semantic search", value=True)
            top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Кількість чанків (top-k)")

    submit_btn = gr.Button("🔍 Знайти відповідь", variant="primary")
    info_box = gr.Markdown("")

    with gr.Tabs():
        with gr.Tab("Відповідь"):
            answer_output = gr.Markdown(label="Відповідь")
        with gr.Tab("Джерела"):
            sources_output = gr.Markdown(label="Використані чанки")

    gr.Examples(
        examples=[
            ["What is the mechanism of action of aspirin?", True, True, 5],
            ["Which vitamin deficiency causes night blindness?", True, True, 5],
            ["What are symptoms of diabetes mellitus?", False, True, 5],
            ["beta blocker mechanism", True, False, 5],
        ],
        inputs=[query_input, use_bm25, use_semantic, top_k],
        label="Приклади запитів",
    )

    submit_btn.click(
        fn=rag_query,
        inputs=[query_input, api_key_input, use_bm25, use_semantic, top_k],
        outputs=[answer_output, sources_output, info_box],
    )

    # The footer names the model actually passed to the API (LLM_MODEL).
    gr.Markdown(f"""
    ---
    **Датасет:** [medmcqa](https://huggingface.co/datasets/medmcqa) | **LLM:** Groq ({LLM_MODEL}) | **Embeddings:** all-MiniLM-L6-v2 | **Chunking:** sliding window (300 слів, overlap 50)
    """)


if __name__ == "__main__":
    demo.launch(share=True)