AbdulMoid commited on
Commit
a650118
·
verified ·
1 Parent(s): 1704015

Create regular_rag.py

Browse files
Files changed (1) hide show
  1. regular_rag.py +78 -0
regular_rag.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import numpy as np
import gradio as gr
import zipfile
import re
from openai import OpenAI
from pinecone import Pinecone
from utils import normalize_path, patient_info, qa_count, patient_number
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# API clients built at import time from environment variables.
# NOTE(review): neither key is validated here — a missing/blank key surfaces
# later as an auth error on the first request, not at startup.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
index = pc.Index('asco-guidelines')  # Use the asco-guidelines index
19
+
20
def get_embedding(text):
    """Embed *text* with OpenAI's text-embedding-3-small model.

    Returns the embedding vector (a list of floats) for the single input.
    """
    result = client.embeddings.create(
        model="text-embedding-3-small",
        input=text,
    )
    return result.data[0].embedding
26
+
27
def query_pinecone(patient_description, user_question, top_k=3):
    """Fetch the top_k guideline chunks most relevant to the patient + question.

    The patient description and the question are embedded separately, the two
    vectors are averaged into one blended query embedding, and the Pinecone
    index is queried with it. Matches missing the expected metadata are logged
    and skipped.

    Returns a list of dicts with keys 'text' and 'source'.
    """
    blended = np.mean(
        [get_embedding(patient_description), get_embedding(user_question)],
        axis=0,
    )
    results = index.query(vector=blended.tolist(), top_k=top_k, include_metadata=True)

    relevant_guidelines = []
    for match in results['matches']:
        # Guard clause: tolerate malformed matches rather than raising.
        if 'metadata' not in match or 'text' not in match['metadata']:
            logger.warning(f"Expected metadata not found in match: {match}")
            continue
        meta = match['metadata']
        relevant_guidelines.append({
            'text': meta['text'],
            'source': meta.get('source', 'Unknown source'),
        })
    return relevant_guidelines
44
+
45
def answer_question(patient_description, user_question, guidelines):
    """Generate a GPT-4 answer grounded in the retrieved guideline excerpts.

    Each guideline dict ('text', 'source') is folded into the prompt context;
    returns the stripped completion text.
    """
    sections = [f"Source: {g['source']}\n{g['text']}" for g in guidelines]
    context = "\n".join(sections)

    prompt = f"""
    You are an AI assistant specialized in ASCO guidelines. Answer the following question based on the provided context.
    Patient Description: {patient_description}
    Question: {user_question}
    Context:
    {context}
    Please provide a detailed and accurate answer based on the patient description and the context. If the information is not sufficient to answer the question completely, please state so and provide the best possible answer with the available information.
    """

    system_message = "You are a knowledgeable medical assistant providing detailed and accurate answers to questions about cancer treatments based on the provided context."
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content.strip()
63
+
64
def qa_tool_regular_rag(user_question):
    """Answer *user_question* about the current patient via plain RAG.

    Retrieves guideline chunks from Pinecone using the module-level
    patient_info["description"], asks GPT-4 to answer from them, and returns
    a 4-tuple shaped for the Gradio UI:
        (answer_text, [], gr.update(visible=True), gr.update(visible=True))

    All exceptions are caught and reported in the answer slot instead of
    being raised, so the UI callback never crashes.
    """
    # NOTE(review): these names were brought in with `from utils import ...`,
    # so `global` rebinds only THIS module's copies — `qa_count += 1` does not
    # update utils.qa_count. Confirm where the counter is actually read.
    # (Removed unused `patient_number` from the original global statement.)
    global patient_info, qa_count
    try:
        relevant_guidelines = query_pinecone(patient_info["description"], user_question)
        if not relevant_guidelines:
            return ("No relevant guidelines found with sufficient similarity.",
                    [], gr.update(visible=True), gr.update(visible=True))

        answer = answer_question(patient_info["description"], user_question, relevant_guidelines)
        qa_count += 1
        # Lazy %-style args (and no placeholder-less f-string): the rendered
        # messages are identical to the originals.
        logger.info("Answer generated.")
        logger.info("Relevant guidelines: %s", relevant_guidelines)
        return answer, [], gr.update(visible=True), gr.update(visible=True)
    except Exception as e:
        logger.error("Error in qa_tool_regular_rag: %s", e, exc_info=True)
        return f"An error occurred: {str(e)}", [], gr.update(visible=True), gr.update(visible=True)