# NeuroCog-MD / MedQAChat.py
# (Uploaded via huggingface_hub, revision e881169)
import json
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple

import dgl
import networkx as nx
import numpy as np
import pandas as pd
import torch
from huggingface_hub import InferenceClient
from rapidfuzz import fuzz, process
from scipy.spatial.distance import cosine
class MedicalKGQASystem:
    """Knowledge-graph-grounded medical question answering system.

    Pipeline: fuzzy-match medical terms in a question against knowledge-graph
    node names, collect related concept chains (bounded graph walks) and
    embedding-space neighbors, build an augmented prompt, query a
    HuggingFace-hosted LLM, and ask the same LLM to validate its own answer
    against the extracted knowledge.
    """

    def __init__(
        self,
        graph_path: str,
        node_info_path: str,
        embeddings_path: str,
        hf_token: str,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
        fuzzy_threshold: int = 100,  # Minimum similarity score (0-100); 100 keeps exact matches only
        device: str = None  # Allow device to be specified; auto-detect when None
    ):
        """Load the graph, node metadata and embeddings, and set up the LLM client.

        Args:
            graph_path: Path to a saved DGL graph (``dgl.save_graphs`` format).
            node_info_path: CSV with at least 'global_graph_index' and
                'node_name' columns.
            embeddings_path: Pickle file with one embedding row per graph node.
            hf_token: HuggingFace API token for the inference client.
            model_name: Hosted model used for both generation and validation.
            fuzzy_threshold: Minimum rapidfuzz score for a term match.
            device: Torch device string; defaults to CUDA when available.
        """
        # Set device (GPU if available, else CPU)
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        # Load DGL graph and move it to the selected device
        self.graph = dgl.load_graphs(graph_path)[0][0]
        if self.device != 'cpu':
            self.graph = self.graph.to(self.device)
        # Load node information, keyed by the node's global graph index
        self.node_info = pd.read_csv(node_info_path)
        self.node_info.set_index('global_graph_index', inplace=True)
        # Bidirectional name <-> node-id mappings
        self.name_to_id = dict(zip(self.node_info['node_name'], self.node_info.index))
        self.id_to_name = dict(zip(self.node_info.index, self.node_info['node_name']))
        # All node names, used as the fuzzy-matching candidate pool
        self.all_terms = list(self.name_to_id.keys())
        # Load embeddings and move them to the selected device.
        # as_tensor avoids an unnecessary copy (and a warning) when the
        # pickled object is already an ndarray/tensor.
        with open(embeddings_path, 'rb') as f:
            embeddings = pickle.load(f)
        self.embeddings = torch.as_tensor(embeddings, device=self.device)
        # Initialize HuggingFace client
        self.client = InferenceClient(model_name, token=hf_token)
        # NetworkX copy of the graph (kept for external/debug use)
        self.nx_graph = dgl.to_networkx(self.graph.cpu())
        # Set fuzzy matching threshold
        self.fuzzy_threshold = fuzzy_threshold
        # Cache term -> (match, score) lookups for performance
        self.term_match_cache = {}

    def fuzzy_match_term(self, term: str) -> Tuple[str, float]:
        """Find the closest matching node name for *term*.

        Returns:
            ``(match, score)``; ``match`` is None and ``score`` is 0 when no
            candidate clears the fuzzy threshold.
        """
        # Check cache first
        if term in self.term_match_cache:
            return self.term_match_cache[term]
        # extractOne returns None when nothing clears score_cutoff;
        # fall back to a (None, 0, None) sentinel in that case.
        match, score, _ = process.extractOne(
            term,
            self.all_terms,
            scorer=fuzz.token_sort_ratio,
            score_cutoff=self.fuzzy_threshold
        ) or (None, 0, None)
        result = (match, score)
        self.term_match_cache[term] = result
        return result

    def extract_medical_terms(self, text: str) -> List[Dict[str, Any]]:
        """Extract medical terms from *text* using windowed fuzzy matching.

        Slides windows of 1..5 words over the lowercased text and keeps, at
        each position, the highest-scoring knowledge-graph match.

        Returns:
            List of dicts with 'original', 'matched' and 'similarity' keys.
        """
        terms = []
        text = text.lower()
        words = text.split()
        # Try different window sizes for multi-word terms
        max_window = 5  # Maximum words in a medical term
        i = 0
        while i < len(words):
            best_match = None
            best_score = 0
            best_window = 0
            # Prefer longer windows first so multi-word terms win ties
            for window in range(max_window, 0, -1):
                if i + window <= len(words):
                    # Create potential term by joining words
                    potential_term = ' '.join(words[i:i + window])
                    # Try fuzzy matching
                    match, score = self.fuzzy_match_term(potential_term)
                    if match and score > best_score:
                        best_match = match
                        best_score = score
                        best_window = window
            if best_match:
                terms.append({
                    'original': ' '.join(words[i:i + best_window]),
                    'matched': best_match,
                    'similarity': best_score
                })
                # Skip past the words consumed by the matched term
                i += best_window
            else:
                i += 1
        return terms

    def batch_cosine_similarity(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor:
        """Return the (n1, n2) cosine-similarity matrix between two embedding sets."""
        # L2-normalize rows, then a matrix product yields cosine similarities
        normalized_emb1 = torch.nn.functional.normalize(embeddings1, p=2, dim=1)
        normalized_emb2 = torch.nn.functional.normalize(embeddings2, p=2, dim=1)
        return torch.mm(normalized_emb1, normalized_emb2.t())

    def get_relation_chains(
        self,
        matched_terms: List[Dict[str, Any]],
        max_depth: int = 5,
        max_chains_per_term: int = 5
    ) -> List[List[str]]:
        """Extract relation chains (paths of node names) rooted at each matched term.

        A bounded depth-first search expands each term's node, visiting
        neighbors in order of embedding similarity to the current node. Paths
        with at least 3 nodes are recorded, up to *max_chains_per_term* per
        term and *max_depth* nodes per path.
        """
        all_chains = []

        def dfs(current_node: int, current_path: List[str], term_chains: List[List[str]], visited: Set[int]):
            # Stop when the path is deep enough or enough chains were collected
            if len(current_path) >= max_depth or len(term_chains) >= max_chains_per_term:
                return
            # Get neighbors using DGL (GPU-accelerated)
            neighbors = self.graph.successors(current_node).cpu().numpy()
            if neighbors.size == 0:
                return  # leaf node: nothing to expand
            # Score neighbors by embedding similarity to the current node (batched)
            neighbor_embeddings = self.embeddings[neighbors]
            current_embedding = self.embeddings[current_node].unsqueeze(0)
            similarities = self.batch_cosine_similarity(current_embedding, neighbor_embeddings)
            # Visit the most similar neighbors first
            neighbor_scores = zip(neighbors.tolist(), similarities[0].cpu().numpy())
            sorted_neighbors = sorted(neighbor_scores, key=lambda x: x[1], reverse=True)
            for neighbor, _ in sorted_neighbors:
                if len(term_chains) >= max_chains_per_term:
                    break
                if neighbor not in visited:
                    neighbor_name = self.id_to_name[neighbor]
                    new_path = current_path + [neighbor_name]
                    # Only paths with at least 3 nodes are informative chains
                    if len(new_path) >= 3:
                        term_chains.append(new_path)
                    visited.add(neighbor)
                    dfs(neighbor, new_path, term_chains, visited)
                    # Backtrack so other branches may reuse this node
                    visited.remove(neighbor)

        # Run one bounded DFS per matched term
        for term_info in matched_terms:
            matched_term = term_info['matched']
            if matched_term in self.name_to_id:
                start_node = self.name_to_id[matched_term]
                term_chains = []
                visited = {start_node}
                dfs(start_node, [matched_term], term_chains, visited)
                all_chains.extend(term_chains[:max_chains_per_term])
        # Shorter chains first: they are easier to read in the prompt
        all_chains.sort(key=len)
        return all_chains

    def get_embedding_neighbors(
        self,
        matched_terms: List[Dict[str, Any]],
        top_k: int = 5
    ) -> Dict[str, List[str]]:
        """Find the top-k nearest nodes in embedding space for each matched term.

        This method was referenced by process_question but previously missing,
        which made every question fail with AttributeError.

        Returns:
            Mapping from matched term name to a list of neighbor node names,
            ranked by cosine similarity (the node itself is excluded).
        """
        neighbors: Dict[str, List[str]] = {}
        for term_info in matched_terms:
            term = term_info['matched']
            node_id = self.name_to_id.get(term)
            if node_id is None:
                continue
            query = self.embeddings[node_id].unsqueeze(0)
            # Batched cosine similarity against every node embedding
            similarities = self.batch_cosine_similarity(query, self.embeddings)[0]
            # top_k + 1 because the best match is always the node itself
            k = min(top_k + 1, similarities.shape[0])
            top_ids = torch.topk(similarities, k).indices.cpu().tolist()
            names = [
                self.id_to_name[idx] for idx in top_ids
                if idx != node_id and idx in self.id_to_name
            ][:top_k]
            if names:
                neighbors[term] = names
        return neighbors

    def construct_augmented_prompt(
        self,
        question: str,
        matched_terms: List[Dict[str, Any]],
        relation_chains: List[List[str]],
        embedding_neighbors: Dict[str, List[str]]
    ) -> str:
        """Construct an augmented prompt with knowledge graph information."""
        prompt = f"""Question: {question}
Relevant medical knowledge:
Identified terms:"""
        # Add matched terms with similarity scores
        for term_info in matched_terms:
            prompt += f"\n- '{term_info['original']}' matched to '{term_info['matched']}' (similarity: {term_info['similarity']}%)"
        # Add relation chains
        if relation_chains:
            prompt += "\n\nRelated concept chains:\n"
            for chain in relation_chains:
                prompt += f"- {' -> '.join(chain)}\n"
        # Add embedding neighbors
        if embedding_neighbors:
            prompt += "\nRelated concepts:\n"
            for term, neighbors in embedding_neighbors.items():
                prompt += f"- {term} is related to: {', '.join(neighbors)}\n"
        prompt += "\nPlease provide a detailed answer with explanation based on the above knowledge.\n"
        return prompt

    def validate_response(
        self,
        response: str,
        matched_terms: List[Dict[str, Any]],
        relation_chains: List[List[str]],
        embedding_neighbors: Dict[str, List[str]]
    ) -> Tuple[bool, str]:
        """
        Validate if the response aligns with the extracted knowledge using the LLM.
        Args:
            response: The model's response to validate
            matched_terms: List of matched medical terms with their info
            relation_chains: List of related concept chains
            embedding_neighbors: Dictionary of related terms from embeddings
        Returns:
            Tuple of (is_valid, explanation)
        """
        # Construct a validation prompt that presents the knowledge and asks for validation
        validation_prompt = f"""Task: Validate if the given response aligns with and is supported by the provided medical knowledge.
Available Medical Knowledge:
1. Medical Terms Found:
{chr(10).join([f"- {term['matched']} (from term: {term['original']})" for term in matched_terms])}
2. Related Concept Chains:
{chr(10).join([f"- {' -> '.join(chain)}" for chain in relation_chains])}
3. Related Concepts:
{chr(10).join([f"- {term} is related to: {', '.join(neighbors)}" for term, neighbors in embedding_neighbors.items()])}
Response to Validate:
{response}
Please analyze if the response:
1. Makes claims that are supported by the provided knowledge
2. Uses medical terms and relationships that appear in the knowledge
3. Doesn't contradict any relationships shown in the concept chains
4. Stays within the scope of the available knowledge
Provide your analysis in the following format:
- Valid (Yes/No):
- Explanation:
- Specific Issues (if any):
Analysis:"""
        # Get validation from the model
        validation_result = self.client.text_generation(
            validation_prompt,
            max_new_tokens=512,
            temperature=0.3,  # Lower temperature for more consistent validation
            top_p=0.95
        )
        # Parse the validation result line by line, looking for the three
        # "- Valid / - Explanation / - Specific Issues" markers we requested.
        try:
            is_valid = False
            explanation = ""
            issues = ""
            for line in validation_result.split('\n'):
                if line.startswith('- Valid'):
                    is_valid = 'yes' in line.lower()
                elif line.startswith('- Explanation'):
                    explanation = line.replace('- Explanation:', '').strip()
                elif line.startswith('- Specific Issues'):
                    issues = line.replace('- Specific Issues:', '').strip()
            if issues and not is_valid:
                return False, f"{explanation} Issues: {issues}"
            return is_valid, explanation
        except Exception as e:
            # Fallback in case of parsing issues
            return False, f"Validation parsing failed. Raw validation result: {validation_result}"

    def process_question(self, question: str) -> Dict:
        """Run one question through the full extract-augment-generate-validate pipeline."""
        # Extract medical terms with fuzzy matching
        matched_terms = self.extract_medical_terms(question)
        # Get knowledge graph information
        relation_chains = self.get_relation_chains(matched_terms)
        embedding_neighbors = self.get_embedding_neighbors(matched_terms)
        # Construct augmented prompt
        prompt = self.construct_augmented_prompt(
            question,
            matched_terms,
            relation_chains,
            embedding_neighbors
        )
        # Get model response
        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95
        )
        # Validate response using the model
        is_valid, validation_message = self.validate_response(
            response,
            matched_terms,
            relation_chains,
            embedding_neighbors
        )
        return {
            "question": question,
            "answer": response,
            "is_valid": is_valid,
            "validation_message": validation_message,
            "knowledge_used": {
                "matched_terms": matched_terms,
                "relation_chains": relation_chains,
                "embedding_neighbors": embedding_neighbors
            }
        }

    def evaluate_medqa(self, test_file: str) -> Dict:
        """Evaluate the system on a MedQA-style JSONL test set.

        Each line must be a JSON object with 'question', 'options' and
        'answer_idx' keys. Scoring is exact string match after stripping,
        which is a very strict proxy for free-form generation.
        """
        results = []
        correct = 0
        total = 0
        with open(test_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                question = data['question']
                true_answer = data['options'][data['answer_idx']]
                result = self.process_question(question)
                predicted_answer = result['answer']
                # Simple exact match accuracy
                is_correct = predicted_answer.strip() == true_answer.strip()
                if is_correct:
                    correct += 1
                total += 1
                results.append({
                    "question": question,
                    "true_answer": true_answer,
                    "predicted_answer": predicted_answer,
                    "is_correct": is_correct,
                    "validation_info": {
                        "is_valid": result['is_valid'],
                        "validation_message": result['validation_message']
                    }
                })
        return {
            # Guard against an empty test file (previously ZeroDivisionError)
            "accuracy": correct / total if total else 0.0,
            "total_questions": total,
            "detailed_results": results
        }
def simple_validate_response(
    client,
    response: str,
    matched_terms: List[Dict[str, any]],
    relation_chains: List[List[str]],
    embedding_neighbors: Dict[str, List[str]]
) -> Tuple[bool, str]:
    """Ask the LLM for a lightweight conflict check of *response*.

    Returns (True, "") when the model replies exactly "VALID", otherwise
    (False, <model's conflict explanation>).
    """
    term_list = ', '.join(item['matched'] for item in matched_terms)
    chain_block = '\n'.join(' -> '.join(chain) for chain in relation_chains)
    validation_prompt = (
        "Given this medical knowledge:\n"
        f"Medical Terms: {term_list}\n"
        "Relationship Chains:\n"
        f"{chain_block}\n"
        "And this response:\n"
        f"{response}\n"
        'Does the response contain any information that conflicts with or contradicts the provided medical knowledge? Reply with just "VALID" if there are no conflicts, or explain the specific conflict if you find one.'
    )
    verdict = client.text_generation(
        validation_prompt,
        max_new_tokens=256,
        temperature=0.3
    ).strip()
    if verdict == "VALID":
        return True, ""
    return False, verdict
import gradio as gr
from typing import List, Dict
import json
def summarize_answer_relationships(
    client,
    answer: str,
    relation_chains: List[List[str]]
) -> str:
    """Ask the LLM for a short summary linking *answer* to the relation chains."""
    chain_block = '\n'.join('→ ' + ' -> '.join(chain) for chain in relation_chains)
    summary_prompt = (
        "Based on these relationship chains from our medical knowledge graph:\n"
        f"{chain_block}\n"
        "And this answer:\n"
        f"{answer}\n"
        "Provide a brief (2-3 sentences) summary of how the answer uses or relates to these medical relationships. Focus only on the clear connections between the answer and the relationship chains shown above."
    )
    summary = client.text_generation(
        summary_prompt,
        max_new_tokens=200,
        temperature=0.3
    )
    return summary.strip()
def create_chat_demo(qa_system: MedicalKGQASystem) -> gr.Interface:
    """Create a Gradio chat interface for the Medical QA system."""

    def format_knowledge(
        knowledge: Dict,
        answer: str,
        client
    ) -> tuple[str, str, str]:
        """Format knowledge into three sections: terms, relationships, and summary."""
        # Format relation chains
        chains_output = []
        if knowledge.get('relation_chains'):
            chains_output.append("🔗 Key Medical Relationships:")
            chains = knowledge['relation_chains']
            # NOTE(review): reverse() mutates the shared knowledge dict in place,
            # so any later consumer of result['knowledge_used'] sees the
            # reversed order — confirm this is intended.
            chains.reverse()
            for chain in chains:
                chains_output.append(f"• {' → '.join(chain)}")
        # Generate relationship summary via the LLM only when both chains
        # and an answer are available (extra model call per message)
        if knowledge.get('relation_chains') and answer:
            summary = summarize_answer_relationships(
                client,
                answer,
                knowledge['relation_chains']
            )
        else:
            summary = "No relationship chains available for summary."
        terms_output = []
        if knowledge.get('matched_terms'):
            terms_output.append("📌 Matched Terms:")
            for term in knowledge['matched_terms']:
                # NOTE(review): original and matched terms are concatenated with
                # no separator — looks like a missing " → " between them; verify.
                terms_output.append(f"• {term['original']}{term['matched']}")
        # Returns (terms text, chains text, summary text) for the three panels
        return (
            "\n".join(terms_output),
            "\n".join(chains_output),
            summary
        )

    def process_chat(
        message: str,
        history: List[List[str]]
    ) -> tuple[List[List[str]], str, str, str, str]:
        """Process chat messages and return response with knowledge details."""
        # Run the full QA pipeline on the user's message
        result = qa_system.process_question(message)
        # Second-pass sanity check of the answer against the extracted knowledge
        is_valid, validation_msg = simple_validate_response(
            qa_system.client,
            result['answer'],
            result['knowledge_used']['matched_terms'],
            result['knowledge_used']['relation_chains'],
            result['knowledge_used']['embedding_neighbors']
        )
        # Surface validation conflicts inline in the chat answer
        answer = result['answer']
        if not is_valid:
            answer += f"\n\n⚠️ Potential Conflict: {validation_msg}"
        # Update chat history (list of [user, bot] pairs, Gradio Chatbot format)
        history.append([message, answer])
        # Format knowledge sections for the side panels
        terms, chains, summary = format_knowledge(
            result['knowledge_used'],
            answer,
            qa_system.client
        )
        return history, terms, chains, summary, "✅ Valid" if is_valid else f"⚠️ {validation_msg}"

    # Build the Gradio UI: chat column on the left, knowledge panels on the right
    with gr.Blocks(title="Medical QA System") as demo:
        gr.Markdown("""
        # 🏥 Medical Knowledge Graph QA
        Ask medical questions and get knowledge-grounded answers
        """)
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat History",
                    height=500
                )
                msg = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a medical question...",
                    lines=2
                )
                clear = gr.Button("Clear Chat")
            with gr.Column(scale=1):
                with gr.Accordion("Knowledge Used", open=True):
                    chains_box = gr.Textbox(
                        label="Relationship Chains",
                        lines=6,
                    )
                    summary_box = gr.Textbox(
                        label="💡 Knowledge Summary",
                        lines=4,
                    )
                    terms_box = gr.Textbox(
                        label="Medical Terms",
                        lines=4,
                    )
                    validation_box = gr.Textbox(
                        label="🔍 Validation Status",
                        lines=2,
                    )
        # Handle message submission (Enter in the textbox)
        msg.submit(
            process_chat,
            [msg, chatbot],
            [chatbot, terms_box, chains_box, summary_box, validation_box]
        )
        # Handle clear button: reset chat history and all knowledge panels
        clear.click(
            lambda: ([], "", "", "", ""),
            None,
            [chatbot, terms_box, chains_box, summary_box, validation_box]
        )
    return demo
import os
# Usage example:
if __name__ == "__main__":
    # Prefer GPU when available; forwarded to the QA system
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    system = MedicalKGQASystem(
        graph_path="new_homo_hg_hms.pt",
        node_info_path="new_node_map_df.csv",
        embeddings_path="full_h_embed_hms.pkl",
        hf_token=os.environ.get("HF_KEY", None),  # HF API token from the environment
        fuzzy_threshold=100,  # 100 = exact-match only; lower to allow fuzzy term matches
        device=device
    )
    # # Process single question
    # question = "What are the potential side effects of metformin?"
    # result = system.process_question(question)
    # print(f"Answer: {result['answer']}")
    # print(f"Validation: {result['validation_message']}")
    # Evaluate on test set
    # eval_results = system.evaluate_medqa("test.jsonl")
    # print(f"Overall accuracy: {eval_results['accuracy']:.2f}")
    # Create and launch the demo
    demo = create_chat_demo(system)
    demo.launch(
        share=True,
        server_name="0.0.0.0",  # bind on all interfaces (required for Spaces)
        server_port=7860
    )