# NOTE: HuggingFace Spaces page-status text ("Spaces: Sleeping") removed from the
# top of this file — it was scrape residue, not part of the application source.
| import torch | |
| import pandas as pd | |
| import pickle | |
| import json | |
| import dgl | |
| import networkx as nx | |
| from huggingface_hub import InferenceClient | |
| from typing import List, Dict, Tuple, Set | |
| import numpy as np | |
| from scipy.spatial.distance import cosine | |
| from collections import defaultdict | |
| from rapidfuzz import fuzz, process | |
class MedicalKGQASystem:
    """Knowledge-graph-grounded medical question answering.

    Pipeline: fuzzy-match medical terms in a question against graph node
    names, collect relation chains and embedding-space neighbors for those
    terms, build a knowledge-augmented prompt, query a hosted LLM, and ask
    the same LLM to validate its answer against the extracted knowledge.
    """

    def __init__(
        self,
        graph_path: str,
        node_info_path: str,
        embeddings_path: str,
        hf_token: str,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
        fuzzy_threshold: int = 100,  # Minimum similarity score (0-100)
        device: str = None  # Allow device to be specified
    ):
        """Load graph, node metadata and embeddings; set up the LLM client.

        Args:
            graph_path: Path to a saved DGL graph file.
            node_info_path: CSV with 'global_graph_index' and 'node_name' columns.
            embeddings_path: Pickle file holding a node-embedding matrix whose
                rows align with the graph's node ids.
            hf_token: HuggingFace API token for the inference client.
            model_name: Hosted model used for generation and validation.
            fuzzy_threshold: Minimum rapidfuzz score (0-100) for a term match;
                100 accepts only exact token-sort matches.
            device: Torch device string; auto-detects CUDA when None.
        """
        # Set device (GPU if available, else CPU)
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        # Load DGL graph and move it to the selected device
        self.graph = dgl.load_graphs(graph_path)[0][0]
        if self.device != 'cpu':
            self.graph = self.graph.to(self.device)
        # Node metadata, indexed by global graph id
        self.node_info = pd.read_csv(node_info_path)
        self.node_info.set_index('global_graph_index', inplace=True)
        # Bidirectional node-name <-> node-id mappings
        self.name_to_id = dict(zip(self.node_info['node_name'], self.node_info.index))
        self.id_to_name = dict(zip(self.node_info.index, self.node_info['node_name']))
        # All node names form the fuzzy-matching vocabulary
        self.all_terms = list(self.name_to_id.keys())
        # Load embeddings and move them to the selected device
        with open(embeddings_path, 'rb') as f:
            embeddings = pickle.load(f)
        self.embeddings = torch.tensor(embeddings, device=self.device)
        # Hosted LLM client used for answering and self-validation
        self.client = InferenceClient(model_name, token=hf_token)
        # NetworkX copy of the graph (conversion requires a CPU graph)
        self.nx_graph = dgl.to_networkx(self.graph.cpu())
        # Set fuzzy matching threshold
        self.fuzzy_threshold = fuzzy_threshold
        # Cache of term -> (match, score) to avoid repeated fuzzy searches
        self.term_match_cache = {}

    def fuzzy_match_term(self, term: str) -> Tuple[str, float]:
        """Find the closest matching term in the knowledge graph.

        Returns:
            (matched_name, score) — or (None, 0) when no candidate clears
            ``self.fuzzy_threshold``.
        """
        # Check cache first
        if term in self.term_match_cache:
            return self.term_match_cache[term]
        # extractOne returns None when nothing clears score_cutoff,
        # hence the `or (None, 0, None)` fallback.
        match, score, _ = process.extractOne(
            term,
            self.all_terms,
            scorer=fuzz.token_sort_ratio,
            score_cutoff=self.fuzzy_threshold
        ) or (None, 0, None)
        result = (match, score)
        self.term_match_cache[term] = result
        return result

    def extract_medical_terms(self, text: str) -> List[Dict[str, any]]:
        """Extract medical terms using fuzzy matching.

        Slides a window of up to 5 words over the lowercased text and keeps
        the best fuzzy match at each position, advancing past matched spans.

        Returns:
            List of dicts with 'original', 'matched' and 'similarity' keys.
        """
        terms = []
        text = text.lower()
        words = text.split()
        # Try different window sizes for multi-word terms
        max_window = 5  # Maximum words in a medical term
        i = 0
        while i < len(words):
            best_match = None
            best_score = 0
            best_window = 0
            # Prefer longer windows by trying them first
            for window in range(max_window, 0, -1):
                if i + window <= len(words):
                    # Create potential term by joining words
                    potential_term = ' '.join(words[i:i + window])
                    match, score = self.fuzzy_match_term(potential_term)
                    if match and score > best_score:
                        best_match = match
                        best_score = score
                        best_window = window
            if best_match:
                terms.append({
                    'original': ' '.join(words[i:i + best_window]),
                    'matched': best_match,
                    'similarity': best_score
                })
                i += best_window
            else:
                i += 1
        return terms

    def batch_cosine_similarity(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor) -> torch.Tensor:
        """Calculate batched cosine similarity between two sets of embeddings.

        Args:
            embeddings1: Tensor of shape (m, d).
            embeddings2: Tensor of shape (n, d).

        Returns:
            (m, n) tensor of pairwise cosine similarities.
        """
        # L2-normalize rows so the matrix product yields cosine similarity
        normalized_emb1 = torch.nn.functional.normalize(embeddings1, p=2, dim=1)
        normalized_emb2 = torch.nn.functional.normalize(embeddings2, p=2, dim=1)
        return torch.mm(normalized_emb1, normalized_emb2.t())

    def get_embedding_neighbors(
        self,
        matched_terms: List[Dict[str, any]],
        k: int = 5
    ) -> Dict[str, List[str]]:
        """Find the k nearest graph nodes to each matched term in embedding space.

        This method was referenced by process_question but previously missing,
        which made every question raise AttributeError.

        Args:
            matched_terms: Output of extract_medical_terms.
            k: Number of neighbors to return per term (the term itself is
               always excluded).

        Returns:
            Mapping of matched term name -> list of neighbor node names,
            ordered by decreasing cosine similarity.
        """
        neighbors: Dict[str, List[str]] = {}
        num_nodes = self.embeddings.shape[0]
        for term_info in matched_terms:
            term = term_info['matched']
            node_id = self.name_to_id.get(term)
            if node_id is None or not (0 <= node_id < num_nodes):
                continue
            term_embedding = self.embeddings[node_id].unsqueeze(0)
            # One row of similarities against every node embedding
            similarities = self.batch_cosine_similarity(term_embedding, self.embeddings)[0]
            # Never return the term itself as its own neighbor
            similarities[node_id] = float('-inf')
            top_k = min(k, num_nodes - 1)
            if top_k <= 0:
                neighbors[term] = []
                continue
            top_ids = torch.topk(similarities, top_k).indices.cpu().tolist()
            neighbors[term] = [self.id_to_name[i] for i in top_ids if i in self.id_to_name]
        return neighbors

    def get_relation_chains(
        self,
        matched_terms: List[Dict[str, any]],
        max_depth: int = 5,
        max_chains_per_term: int = 5
    ) -> List[List[str]]:
        """Extract relation chains via similarity-guided DFS from each term.

        Neighbors are visited in order of embedding similarity to the current
        node; only chains of at least 3 nodes are kept. Chains are returned
        shortest-first.
        """
        all_chains = []

        def dfs(current_node: int, current_path: List[str], term_chains: List[List[str]], visited: Set[int]):
            # Stop on depth or per-term quota limits
            if len(current_path) >= max_depth or len(term_chains) >= max_chains_per_term:
                return
            # Get outgoing neighbors from the DGL graph
            neighbors = self.graph.successors(current_node).cpu().numpy()
            if len(neighbors) == 0:
                # Leaf node: nothing to score (avoids mm on an empty tensor)
                return
            # Score neighbors by embedding similarity to the current node
            neighbor_embeddings = self.embeddings[neighbors]
            current_embedding = self.embeddings[current_node].unsqueeze(0)
            similarities = self.batch_cosine_similarity(current_embedding, neighbor_embeddings)
            # Visit the most similar neighbors first
            neighbor_scores = zip(neighbors.tolist(), similarities[0].cpu().numpy())
            sorted_neighbors = sorted(neighbor_scores, key=lambda x: x[1], reverse=True)
            for neighbor, _ in sorted_neighbors:
                if len(term_chains) >= max_chains_per_term:
                    break
                if neighbor not in visited:
                    neighbor_name = self.id_to_name[neighbor]
                    new_path = current_path + [neighbor_name]
                    # Only keep chains with at least the term plus two hops
                    if len(new_path) >= 3:
                        term_chains.append(new_path)
                    visited.add(neighbor)
                    dfs(neighbor, new_path, term_chains, visited)
                    visited.remove(neighbor)

        # Process each matched term independently
        for term_info in matched_terms:
            matched_term = term_info['matched']
            if matched_term in self.name_to_id:
                start_node = self.name_to_id[matched_term]
                term_chains = []
                visited = {start_node}
                dfs(start_node, [matched_term], term_chains, visited)
                all_chains.extend(term_chains[:max_chains_per_term])
        # Shortest (most direct) chains first
        all_chains.sort(key=len)
        return all_chains

    def construct_augmented_prompt(
        self,
        question: str,
        matched_terms: List[Dict[str, any]],
        relation_chains: List[List[str]],
        embedding_neighbors: Dict[str, List[str]]
    ) -> str:
        """Construct an augmented prompt with knowledge graph information."""
        prompt = f"""Question: {question}
Relevant medical knowledge:
Identified terms:"""
        # Add matched terms with similarity scores
        for term_info in matched_terms:
            prompt += f"\n- '{term_info['original']}' matched to '{term_info['matched']}' (similarity: {term_info['similarity']}%)"
        # Add relation chains
        if relation_chains:
            prompt += "\n\nRelated concept chains:\n"
            for chain in relation_chains:
                prompt += f"- {' -> '.join(chain)}\n"
        # Add embedding neighbors
        if embedding_neighbors:
            prompt += "\nRelated concepts:\n"
            for term, neighbors in embedding_neighbors.items():
                prompt += f"- {term} is related to: {', '.join(neighbors)}\n"
        prompt += "\nPlease provide a detailed answer with explanation based on the above knowledge.\n"
        return prompt

    def validate_response(
        self,
        response: str,
        matched_terms: List[Dict[str, any]],
        relation_chains: List[List[str]],
        embedding_neighbors: Dict[str, List[str]]
    ) -> Tuple[bool, str]:
        """
        Validate if the response aligns with the extracted knowledge using the LLM.
        Args:
            response: The model's response to validate
            matched_terms: List of matched medical terms with their info
            relation_chains: List of related concept chains
            embedding_neighbors: Dictionary of related terms from embeddings
        Returns:
            Tuple of (is_valid, explanation)
        """
        # Present the knowledge and the answer, ask for a structured verdict
        validation_prompt = f"""Task: Validate if the given response aligns with and is supported by the provided medical knowledge.
Available Medical Knowledge:
1. Medical Terms Found:
{chr(10).join([f"- {term['matched']} (from term: {term['original']})" for term in matched_terms])}
2. Related Concept Chains:
{chr(10).join([f"- {' -> '.join(chain)}" for chain in relation_chains])}
3. Related Concepts:
{chr(10).join([f"- {term} is related to: {', '.join(neighbors)}" for term, neighbors in embedding_neighbors.items()])}
Response to Validate:
{response}
Please analyze if the response:
1. Makes claims that are supported by the provided knowledge
2. Uses medical terms and relationships that appear in the knowledge
3. Doesn't contradict any relationships shown in the concept chains
4. Stays within the scope of the available knowledge
Provide your analysis in the following format:
- Valid (Yes/No):
- Explanation:
- Specific Issues (if any):
Analysis:"""
        # Get validation from the model
        validation_result = self.client.text_generation(
            validation_prompt,
            max_new_tokens=512,
            temperature=0.3,  # Lower temperature for more consistent validation
            top_p=0.95
        )
        # Parse the structured verdict line-by-line
        try:
            is_valid = False
            explanation = ""
            issues = ""
            for line in validation_result.split('\n'):
                if line.startswith('- Valid'):
                    is_valid = 'yes' in line.lower()
                elif line.startswith('- Explanation'):
                    explanation = line.replace('- Explanation:', '').strip()
                elif line.startswith('- Specific Issues'):
                    issues = line.replace('- Specific Issues:', '').strip()
            if issues and not is_valid:
                return False, f"{explanation} Issues: {issues}"
            return is_valid, explanation
        except Exception as e:
            # Fallback in case of parsing issues
            return False, f"Validation parsing failed. Raw validation result: {validation_result}"

    def process_question(self, question: str) -> Dict:
        """Process a single question through the full QA pipeline.

        Returns:
            Dict with 'question', 'answer', 'is_valid', 'validation_message'
            and a 'knowledge_used' sub-dict of the extracted knowledge.
        """
        # Extract medical terms with fuzzy matching
        matched_terms = self.extract_medical_terms(question)
        # Get knowledge graph information
        relation_chains = self.get_relation_chains(matched_terms)
        embedding_neighbors = self.get_embedding_neighbors(matched_terms)
        # Construct augmented prompt
        prompt = self.construct_augmented_prompt(
            question,
            matched_terms,
            relation_chains,
            embedding_neighbors
        )
        # Get model response
        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95
        )
        # Validate response using the model itself
        is_valid, validation_message = self.validate_response(
            response,
            matched_terms,
            relation_chains,
            embedding_neighbors
        )
        return {
            "question": question,
            "answer": response,
            "is_valid": is_valid,
            "validation_message": validation_message,
            "knowledge_used": {
                "matched_terms": matched_terms,
                "relation_chains": relation_chains,
                "embedding_neighbors": embedding_neighbors
            }
        }

    def evaluate_medqa(self, test_file: str) -> Dict:
        """Evaluate the system on a MedQA-style JSONL test set.

        Each line must contain 'question', 'options' and 'answer_idx'.
        Accuracy is exact string match between the generated answer and
        the gold option text (a very strict criterion for free-form LLMs).
        """
        results = []
        correct = 0
        total = 0
        with open(test_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                question = data['question']
                true_answer = data['options'][data['answer_idx']]
                result = self.process_question(question)
                predicted_answer = result['answer']
                # Simple exact match accuracy
                if predicted_answer.strip() == true_answer.strip():
                    correct += 1
                total += 1
                results.append({
                    "question": question,
                    "true_answer": true_answer,
                    "predicted_answer": predicted_answer,
                    "is_correct": predicted_answer.strip() == true_answer.strip(),
                    "validation_info": {
                        "is_valid": result['is_valid'],
                        "validation_message": result['validation_message']
                    }
                })
        return {
            # Guard against an empty test file (previously ZeroDivisionError)
            "accuracy": correct / total if total else 0.0,
            "total_questions": total,
            "detailed_results": results
        }
def simple_validate_response(
    client,
    response: str,
    matched_terms: List[Dict[str, any]],
    relation_chains: List[List[str]],
    embedding_neighbors: Dict[str, List[str]]
) -> Tuple[bool, str]:
    """Lightweight LLM-based conflict check of a response against the knowledge.

    Returns (True, "") when the model replies exactly "VALID"; otherwise
    (False, <the model's conflict explanation>).
    """
    # Pre-render the knowledge sections, then assemble the prompt in one shot.
    term_listing = ', '.join(info['matched'] for info in matched_terms)
    chain_listing = '\n'.join(' -> '.join(chain) for chain in relation_chains)
    validation_prompt = f"""Given this medical knowledge:
Medical Terms: {term_listing}
Relationship Chains:
{chain_listing}
And this response:
{response}
Does the response contain any information that conflicts with or contradicts the provided medical knowledge? Reply with just "VALID" if there are no conflicts, or explain the specific conflict if you find one."""
    verdict = client.text_generation(
        validation_prompt,
        max_new_tokens=256,
        temperature=0.3
    ).strip()
    if verdict == "VALID":
        return True, ""
    return False, verdict
| import gradio as gr | |
| from typing import List, Dict | |
| import json | |
def summarize_answer_relationships(
    client,
    answer: str,
    relation_chains: List[List[str]]
) -> str:
    """Ask the LLM for a short summary linking the answer to the chains.

    Returns the model's (whitespace-stripped) 2-3 sentence summary of how
    the answer relates to the supplied knowledge-graph chains.
    """
    # Render each chain as an arrow-prefixed line before building the prompt.
    chain_lines = '\n'.join('→ ' + ' -> '.join(chain) for chain in relation_chains)
    summary_prompt = f"""Based on these relationship chains from our medical knowledge graph:
{chain_lines}
And this answer:
{answer}
Provide a brief (2-3 sentences) summary of how the answer uses or relates to these medical relationships. Focus only on the clear connections between the answer and the relationship chains shown above."""
    raw_summary = client.text_generation(
        summary_prompt,
        max_new_tokens=200,
        temperature=0.3
    )
    return raw_summary.strip()
def create_chat_demo(qa_system: MedicalKGQASystem) -> gr.Blocks:
    """Create a Gradio chat interface for the Medical QA system.

    Builds a two-column Blocks layout: chat history + input on the left,
    knowledge panels (terms, chains, summary, validation status) on the
    right. Returns the Blocks object; the caller is responsible for
    launching it.
    """
    def format_knowledge(
        knowledge: Dict,
        answer: str,
        client
    ) -> tuple[str, str, str]:
        """Format knowledge into three sections: terms, relationships, and summary."""
        # Format relation chains
        chains_output = []
        if knowledge.get('relation_chains'):
            chains_output.append("🔗 Key Medical Relationships:")
            chains = knowledge['relation_chains']
            # NOTE(review): reverse() mutates the caller's list in place; since
            # chains arrive sorted shortest-first, this displays longest first.
            chains.reverse()
            for chain in chains:
                chains_output.append(f"• {' → '.join(chain)}")
        # Generate relationship summary via an extra LLM call
        if knowledge.get('relation_chains') and answer:
            summary = summarize_answer_relationships(
                client,
                answer,
                knowledge['relation_chains']
            )
        else:
            summary = "No relationship chains available for summary."
        terms_output = []
        if knowledge.get('matched_terms'):
            terms_output.append("📌 Matched Terms:")
            for term in knowledge['matched_terms']:
                terms_output.append(f"• {term['original']} → {term['matched']}")
        return (
            "\n".join(terms_output),
            "\n".join(chains_output),
            summary
        )
    def process_chat(
        message: str,
        history: List[List[str]]
    ) -> tuple[List[List[str]], str, str, str, str]:
        """Process chat messages and return response with knowledge details.

        Output order must match the submit handler's output components:
        (chatbot, terms, chains, summary, validation status).
        """
        # Run the full QA pipeline for this question
        result = qa_system.process_question(message)
        # Second, lightweight validation pass (in addition to the pipeline's own)
        is_valid, validation_msg = simple_validate_response(
            qa_system.client,
            result['answer'],
            result['knowledge_used']['matched_terms'],
            result['knowledge_used']['relation_chains'],
            result['knowledge_used']['embedding_neighbors']
        )
        # Surface conflicts directly in the chat answer
        answer = result['answer']
        if not is_valid:
            answer += f"\n\n⚠️ Potential Conflict: {validation_msg}"
        # Update chat history (list-of-pairs format expected by gr.Chatbot)
        history.append([message, answer])
        # Format knowledge sections for the side panels
        terms, chains, summary = format_knowledge(
            result['knowledge_used'],
            answer,
            qa_system.client
        )
        return history, terms, chains, summary, "✅ Valid" if is_valid else f"⚠️ {validation_msg}"
    # Create the Gradio interface
    with gr.Blocks(title="Medical QA System") as demo:
        gr.Markdown("""
        # 🏥 Medical Knowledge Graph QA
        Ask medical questions and get knowledge-grounded answers
        """)
        with gr.Row():
            # Left column: chat history and question input
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat History",
                    height=500
                )
                msg = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask a medical question...",
                    lines=2
                )
                clear = gr.Button("Clear Chat")
            # Right column: knowledge transparency panels
            with gr.Column(scale=1):
                with gr.Accordion("Knowledge Used", open=True):
                    chains_box = gr.Textbox(
                        label="Relationship Chains",
                        lines=6,
                    )
                    summary_box = gr.Textbox(
                        label="💡 Knowledge Summary",
                        lines=4,
                    )
                    terms_box = gr.Textbox(
                        label="Medical Terms",
                        lines=4,
                    )
                    validation_box = gr.Textbox(
                        label="🔍 Validation Status",
                        lines=2,
                    )
        # Handle message submission; output order mirrors process_chat's return
        msg.submit(
            process_chat,
            [msg, chatbot],
            [chatbot, terms_box, chains_box, summary_box, validation_box]
        )
        # Clear button resets the chat and all knowledge panels
        clear.click(
            lambda: ([], "", "", "", ""),
            None,
            [chatbot, terms_box, chains_box, summary_box, validation_box]
        )
    return demo
| import os | |
# Usage example / Gradio entry point.
if __name__ == "__main__":
    # Prefer GPU when one is visible to torch, otherwise run on CPU.
    selected_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    qa_system = MedicalKGQASystem(
        graph_path="new_homo_hg_hms.pt",
        node_info_path="new_node_map_df.csv",
        embeddings_path="full_h_embed_hms.pkl",
        hf_token=os.environ.get("HF_KEY", None),
        fuzzy_threshold=100,
        device=selected_device
    )
    # Programmatic usage (instead of the UI):
    #   result = qa_system.process_question("What are the potential side effects of metformin?")
    #   print(result['answer'], result['validation_message'])
    #   print(qa_system.evaluate_medqa("test.jsonl")['accuracy'])
    # Build and launch the chat UI, reachable from other hosts on port 7860.
    demo = create_chat_demo(qa_system)
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )