import json
import logging
import os
from typing import Dict, List, Optional, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import DocumentProcessor
|
|
class HealthcareRAG:
    """Answer questions about a clinical-guidelines PDF with retrieval-augmented generation.

    Chunking, indexing and retrieval are delegated to ``utils.DocumentProcessor``;
    answers are generated by a Hugging Face seq2seq model
    (``google/flan-t5-base`` by default).
    """

    # Canonical refusal text. It is returned when nothing was retrieved and is
    # also quoted in the prompt as the exact phrase the model should emit.
    NOT_AVAILABLE_MESSAGE = "This information is not available in the current guidelines."

    # Appended to every non-empty, non-refusal answer.
    DISCLAIMER = "\n\nDISCLAIMER: This information is for educational purposes only and not medical advice."

    def __init__(self,
                 model_name: str = "google/flan-t5-base",
                 index_path: str = "faiss_index.bin",
                 pdf_path: str = "clinical_guidelines.pdf"):
        """Load the model and either reuse a cached index or build one from the PDF.

        Args:
            model_name: Hugging Face model id or local path of a seq2seq model.
            index_path: Location of the retrieval index. Chunk texts are cached
                next to it in ``<index_path>.chunks.json``.
            pdf_path: Source document; processed only when no usable cache exists.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

        self.doc_processor = DocumentProcessor()

        # retrieve_chunks() is given both the index and the chunk texts, so a
        # cached index is only usable together with the chunks it was built
        # from. Pairing it with an empty list would leave retrieval nothing to
        # return (assuming retrieve_chunks maps index hits into that list).
        chunks_path = f"{index_path}.chunks.json"
        cached_chunks = self._load_chunks(chunks_path) if os.path.exists(index_path) else None
        if cached_chunks is not None:
            self.index = self.doc_processor.load_index(index_path)
            self.chunks = cached_chunks
        else:
            # NOTE(review): assumes process_document (re)writes the index at
            # index_path, as the original cache check implies; confirm.
            self.chunks, self.index = self.doc_processor.process_document(pdf_path, index_path)
            self._save_chunks(chunks_path, self.chunks)

    @staticmethod
    def _load_chunks(path: str) -> Optional[List[str]]:
        """Return the JSON list cached at *path*, or None if it is missing, unreadable or not a list."""
        try:
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as err:  # ValueError covers JSON/Unicode decode errors
            logging.getLogger(__name__).warning("Ignoring unreadable chunk cache %s: %s", path, err)
            return None
        return data if isinstance(data, list) else None

    @staticmethod
    def _save_chunks(path: str, chunks: List[str]) -> None:
        """Best-effort write of *chunks* to *path* as JSON; failures are logged, not raised."""
        try:
            # Serialize before opening the file so a TypeError cannot leave a
            # truncated cache behind.
            payload = json.dumps(list(chunks), ensure_ascii=False)
            with open(path, "w", encoding="utf-8") as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as err:
            logging.getLogger(__name__).warning("Could not cache chunks to %s: %s", path, err)

    @classmethod
    def _render_prompt(cls, context: str, query: str) -> str:
        """Fill the instruction template with *context* and *query*."""
        return (
            "Based on the following clinical guidelines, answer the question.\n"
            "If the information is not explicitly stated in the guidelines, "
            f'respond with "{cls.NOT_AVAILABLE_MESSAGE}"\n'
            "\n"
            "Guidelines:\n"
            f"{context}\n"
            "\n"
            f"Question: {query}\n"
            "\n"
            "Answer:"
        )

    def _build_prompt(self, query: str, retrieved_chunks: List[str],
                      max_input_tokens: int = 1024) -> str:
        """Render the prompt, trimming only the context so it fits *max_input_tokens*.

        Truncating the whole prompt from the right would cut off the question
        and the trailing "Answer:" cue, so the instructions and question are
        budgeted first and the context gets whatever tokens remain.
        """
        context = "\n".join(retrieved_chunks)
        # Tokens used by everything except the context (instructions, question, EOS).
        overhead = len(self.tokenizer.encode(self._render_prompt("", query)))
        budget = max(max_input_tokens - overhead, 0)
        context_ids = self.tokenizer.encode(context, add_special_tokens=False)
        if len(context_ids) > budget:
            # Keep the leading part of the context; trailing chunks are dropped first.
            context = self.tokenizer.decode(context_ids[:budget], skip_special_tokens=True)
        return self._render_prompt(context, query)

    @classmethod
    def _is_not_available(cls, text: str) -> bool:
        """Return True if *text* is the refusal message, ignoring case, whitespace, quotes and a trailing period."""
        def _normalize(s: str) -> str:
            return s.strip().strip('"').strip().rstrip(".").casefold()
        return _normalize(text) == _normalize(cls.NOT_AVAILABLE_MESSAGE)

    def generate_response(self, query: str, retrieved_chunks: List[str], *,
                          max_input_tokens: int = 1024) -> str:
        """Generate an answer to *query* grounded in *retrieved_chunks*.

        Args:
            query: The user's question.
            retrieved_chunks: Guideline excerpts to place in the prompt.
            max_input_tokens: Encoder token budget; the context is trimmed to fit.

        Returns:
            ``NOT_AVAILABLE_MESSAGE`` if nothing was retrieved or the model
            refuses. An empty string if the model produced nothing. Otherwise
            the model's answer followed by ``DISCLAIMER``.
        """
        if not retrieved_chunks:
            return self.NOT_AVAILABLE_MESSAGE

        prompt = self._build_prompt(query, retrieved_chunks, max_input_tokens)
        # truncation=True stays as a safety net: token counts are not strictly
        # additive across the template/context boundary.
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=max_input_tokens,
                                truncation=True).to(self.device)
        # NOTE(review): do_sample=True with num_beams>1 selects beam-search
        # multinomial sampling, so answers are non-deterministic; consider
        # greedy/beam decoding for clinical use.
        outputs = self.model.generate(
            **inputs,
            max_length=200,
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        if not response:
            return response
        if self._is_not_available(response):
            # Normalize near-miss refusals so callers always see the canonical text.
            return self.NOT_AVAILABLE_MESSAGE
        return response + self.DISCLAIMER

    def query(self, user_query: str) -> Dict[str, Union[str, List[str]]]:
        """Retrieve supporting chunks for *user_query* and generate an answer.

        Returns:
            ``{"response": <answer text>, "retrieved_chunks": <chunks given to the model>}``.
        """
        retrieved_chunks = self.doc_processor.retrieve_chunks(
            user_query,
            self.index,
            self.chunks,
        )
        response = self.generate_response(user_query, retrieved_chunks)
        return {
            "response": response,
            "retrieved_chunks": retrieved_chunks,
        }