File size: 3,621 Bytes
a650118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import json
import numpy as np
import gradio as gr
import zipfile
import re
from openai import OpenAI
from pinecone import Pinecone
from utils import normalize_path, patient_info, qa_count, patient_number
import logging

# Set up logging
# Module-level logger configured once at import time; all functions below log through it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Shared API clients, constructed at import time from environment variables.
# NOTE(review): if either env var is unset the client gets api_key=None and
# calls will fail at request time, not here — confirm that is acceptable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
index = pc.Index('asco-guidelines')  # Use the asco-guidelines index

def get_embedding(text):
    """Embed *text* with OpenAI's text-embedding-3-small model.

    Returns the embedding vector of the first (only) result.
    """
    result = client.embeddings.create(model="text-embedding-3-small", input=text)
    return result.data[0].embedding

def query_pinecone(patient_description, user_question, top_k=3):
    """Retrieve up to *top_k* guideline snippets relevant to patient + question.

    Embeds the patient description and the question separately, averages the
    two vectors, and queries the Pinecone index with the combined vector.
    Returns a list of {'text': ..., 'source': ...} dicts for every match that
    carries text metadata; matches without it are logged and skipped.
    """
    vectors = [get_embedding(patient_description), get_embedding(user_question)]
    query_vector = np.mean(vectors, axis=0).tolist()
    response = index.query(vector=query_vector, top_k=top_k, include_metadata=True)

    guidelines = []
    for hit in response['matches']:
        # Skip malformed matches rather than crashing the whole query.
        if 'metadata' not in hit or 'text' not in hit['metadata']:
            logger.warning(f"Expected metadata not found in match: {hit}")
            continue
        meta = hit['metadata']
        guidelines.append({
            'text': meta['text'],
            'source': meta.get('source', 'Unknown source'),
        })
    return guidelines

def answer_question(patient_description, user_question, guidelines):
    """Generate an answer to *user_question* grounded in the given guidelines.

    Formats each guideline as "Source: <source>\\n<text>", joins them into a
    context block, and asks gpt-4 to answer based on the patient description
    plus that context. Returns the stripped completion text.
    """
    context_parts = [f"Source: {entry['source']}\n{entry['text']}" for entry in guidelines]
    context = "\n".join(context_parts)
    prompt = f"""
    You are an AI assistant specialized in ASCO guidelines. Answer the following question based on the provided context.
    Patient Description: {patient_description}
    Question: {user_question}
    Context:
    {context}
    Please provide a detailed and accurate answer based on the patient description and the context. If the information is not sufficient to answer the question completely, please state so and provide the best possible answer with the available information.
    """
    system_message = {
        "role": "system",
        "content": "You are a knowledgeable medical assistant providing detailed and accurate answers to questions about cancer treatments based on the provided context.",
    }
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=[system_message, {"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content.strip()

def qa_tool_regular_rag(user_question):
    """Answer *user_question* for the current patient via the plain RAG pipeline.

    Looks up guidelines relevant to the stored patient description, generates
    an answer with the LLM, and returns the 4-tuple the Gradio UI expects:
    (answer_text, [], gr.update(visible=True), gr.update(visible=True)).
    On any failure the error is logged and returned as the answer text.
    """
    # NOTE(review): patient_info/qa_count were imported from utils, so the
    # `global` statement (and `qa_count += 1`) rebinds only THIS module's
    # globals, not utils.qa_count — confirm that is the intended counter.
    # (Removed unused `patient_number` from the global declaration.)
    global patient_info, qa_count
    try:
        relevant_guidelines = query_pinecone(patient_info["description"], user_question)
        # Guard clause: nothing retrieved -> report and bail out early.
        if not relevant_guidelines:
            return ("No relevant guidelines found with sufficient similarity.",
                    [], gr.update(visible=True), gr.update(visible=True))
        answer = answer_question(patient_info["description"], user_question, relevant_guidelines)
        qa_count += 1
        # Lazy %-style args so formatting only happens if the level is enabled.
        logger.info("Answer generated.")
        logger.info("Relevant guidelines: %s", relevant_guidelines)
        return answer, [], gr.update(visible=True), gr.update(visible=True)
    except Exception as e:
        # Top-level UI boundary: log with traceback, surface the message to the user.
        logger.error("Error in qa_tool_regular_rag: %s", e, exc_info=True)
        return f"An error occurred: {str(e)}", [], gr.update(visible=True), gr.update(visible=True)