# ASCO-based-QA / graph_rag.py
# AbdulMoid's picture
# Update graph_rag.py
# dcb46b0 verified
import os
import zipfile
import re
import shutil
import logging
import subprocess
from dotenv import load_dotenv
import gradio as gr
from utils import patient_info # Importing patient_info from utils
# Configure the root logger once at import time: INFO level, with each record
# showing timestamp, logger name, level and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Load environment variables (python-dotenv). qa_tool_graph_rag later reads
# ZIP_PATH and EXTRACT_PATH from the environment via os.getenv.
load_dotenv()
def unzip_folder(zip_path, extract_path):
    """Extract a zip archive and return the directory that holds its data.

    The archive is unpacked directly into ``extract_path``. The archive is
    expected to contain a ``ragtest/ragtest`` directory; when that nested
    directory exists after extraction it is returned (and its entries are
    logged). Otherwise an error is logged and ``<extract_path>/ragtest`` is
    returned as a fallback.

    Args:
        zip_path: Path of the zip archive to extract.
        extract_path: Directory the archive is extracted into.

    Returns:
        ``<extract_path>/ragtest/ragtest`` if it exists after extraction,
        else ``<extract_path>/ragtest`` (which is always created up front).
    """
    outer_dir = os.path.join(extract_path, "ragtest")
    os.makedirs(outer_dir, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_path)

    # The zip is expected to carry an extra "ragtest" level inside "ragtest".
    nested_dir = os.path.join(outer_dir, "ragtest")

    if not os.path.exists(nested_dir):
        logger.error(f"Expected directory {nested_dir} does not exist. Check the structure of the zip file.")
        # Fall back to the outer directory when the layout is not as expected.
        return outer_dir

    logger.info(f"Extracted contents to {nested_dir}")
    logger.info(f"Contents of {nested_dir}:")
    for entry in os.listdir(nested_dir):
        logger.info(os.path.join(nested_dir, entry))
    return nested_dir
def run_graphrag_query(query, ragtest_dir):
    """Run a GraphRAG global-search query in a subprocess and return its output.

    The entries of ``ragtest_dir`` are logged first, then
    ``<interpreter> -m graphrag.query --root <ragtest_dir> --method global <query>``
    is executed with stdout/stderr captured as text.

    Args:
        query: Query text, passed as the final positional argument.
        ragtest_dir: Directory passed to ``graphrag.query`` as ``--root``.

    Returns:
        The subprocess's stdout when it exits with status 0; otherwise its
        stderr (which is also logged at ERROR level).

    Raises:
        OSError: If ``ragtest_dir`` cannot be listed or the interpreter
            cannot be launched.
    """
    # Function-scope import keeps this block self-contained.
    import sys

    # Log the directory and its contents
    logger.info(f"Running GraphRAG query with root: {ragtest_dir}")
    logger.info(f"Contents of {ragtest_dir}:")
    for file in os.listdir(ragtest_dir):
        logger.info(os.path.join(ragtest_dir, file))

    # Use the interpreter that is running this process so the subprocess sees
    # the same installed packages (a bare "python" may resolve to a different
    # interpreter on PATH, or not exist at all). sys.executable can be empty
    # or None when Python cannot determine its own path; fall back to the
    # previous behaviour in that case.
    interpreter = sys.executable or "python"

    # Argument list (no shell) so the query text is never shell-interpreted.
    command = [
        interpreter, "-m", "graphrag.query",
        "--root", ragtest_dir,
        "--method", "global",
        query,
    ]

    result = subprocess.run(command, capture_output=True, text=True)

    if result.returncode == 0:
        return result.stdout

    logger.error(f"GraphRAG query failed with error: {result.stderr}")
    return result.stderr
def clean_response(response):
    """Strip everything up to and including the global-search success marker.

    Args:
        response: Raw text output of the GraphRAG query.

    Returns:
        The text following the first ``"SUCCESS: Global Search Response:"``
        marker with surrounding whitespace removed, or ``response`` unchanged
        when the marker is absent.
    """
    marker = "SUCCESS: Global Search Response:"
    # partition splits on the first occurrence; an empty separator means the
    # marker was not found and the input must be returned untouched.
    _, found, remainder = response.partition(marker)
    return remainder.strip() if found else response
def qa_tool_graph_rag(user_question):
    """Answer a user question with GraphRAG, prefixed by the patient information.

    Extracts the archive named by the ``ZIP_PATH`` environment variable into
    ``EXTRACT_PATH`` (both have hard-coded defaults), runs a GraphRAG query on
    ``patient_info`` plus the question, and trims the raw output down to the
    answer text. The extracted directory is removed and the working directory
    restored afterwards, whether or not the query succeeded.

    Args:
        user_question: The question text entered by the user.

    Returns:
        A 4-tuple ``(answer, images, update, update)``. On success ``answer``
        is the cleaned GraphRAG output, ``images`` is an empty list and both
        ``gr.update`` values set ``visible=True``. On any exception ``answer``
        is an ``"An error occurred: ..."`` message and both updates set
        ``visible=False``.
    """
    original_dir = os.getcwd()  # Store the original directory
    # Stays None until extraction succeeds, so cleanup knows whether there is
    # anything to remove.
    output_dir = None
    try:
        zip_path = os.getenv('ZIP_PATH', '/home/user/app/ragtest.zip')
        extract_path = os.getenv('EXTRACT_PATH', '/home/user/app')
        output_dir = unzip_folder(zip_path, extract_path)
        # NOTE(review): os.chdir changes the working directory of the whole
        # process, so overlapping calls could interfere with each other.
        os.chdir(extract_path)

        # Combine patient_info with user_question
        combined_input = f"{patient_info}\n\n{user_question}"

        # Run the GraphRAG query with the combined input
        raw_answer = run_graphrag_query(combined_input, output_dir)

        # Clean the response to remove everything before "SUCCESS: Global Search Response:"
        answer = clean_response(raw_answer)
        logger.info(f"GraphRAG answer generated: {answer}")

        images = []  # Adjust as needed for your application
        return answer, images, gr.update(visible=True), gr.update(visible=True)
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback.
        logger.exception(f"Error in GraphRAG processing: {str(e)}")
        return f"An error occurred: {str(e)}", [], gr.update(visible=False), gr.update(visible=False)
    finally:
        if output_dir is not None:
            try:
                shutil.rmtree(output_dir)
            except OSError as cleanup_error:
                # A failed cleanup must not discard the result being returned
                # nor prevent the working directory from being restored.
                logger.warning(f"Could not remove {output_dir}: {cleanup_error}")
        os.chdir(original_dir)