Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Embedding Model Evaluator for Medical Content | |
| Tests different free embedding models to find the best for maternal health guidelines | |
| """ | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Tuple | |
| import logging | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.cluster import KMeans | |
| from sklearn.decomposition import PCA | |
| import matplotlib.pyplot as plt | |
| import time | |
# Module-wide logging setup: INFO and above go through the root logger.
# (logging.basicConfig does nothing if the root logger already has handlers.)
_DEFAULT_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_DEFAULT_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
class MedicalEmbeddingEvaluator:
    """Evaluates different embedding models for medical content quality.

    Loads pre-chunked guideline text, embeds a sample of chunks and a fixed
    set of clinical test queries with each candidate SentenceTransformer
    model, scores retrieval and clustering behaviour against the chunks'
    metadata flags, and ranks the models by a weighted overall score.
    """

    # Number of top-ranked chunks inspected per query by the search metrics.
    TOP_K = 5
    # Maximum number of chunks embedded per model (the first N loaded).
    MAX_SAMPLE_CHUNKS = 100
    # Upper bound on the number of KMeans clusters.
    MAX_CLUSTERS = 8

    def __init__(self, chunks_dir: Path = Path("comprehensive_chunks")):
        """Set up the candidate models and the clinical test queries.

        Args:
            chunks_dir: Directory expected to contain
                ``langchain_documents_comprehensive.json``.
        """
        self.chunks_dir = chunks_dir
        self.medical_chunks = []
        self.evaluation_results = {}

        # Free embedding models to test: display name -> model id/path.
        self.embedding_models = {
            'all-MiniLM-L6-v2': 'sentence-transformers/all-MiniLM-L6-v2',
            'all-mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2',
            'all-MiniLM-L12-v2': 'sentence-transformers/all-MiniLM-L12-v2',
            'multi-qa-MiniLM-L6-cos-v1': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
            'all-distilroberta-v1': 'sentence-transformers/all-distilroberta-v1'
        }

        # Medical test queries for evaluation. Queries mentioning "dosage"/"dose"
        # or "emergency"/"urgent" also feed the category-specific accuracies.
        self.test_queries = [
            "What is the recommended dosage of magnesium sulfate for preeclampsia?",
            "How to manage postpartum hemorrhage in emergency situations?",
            "Normal ranges for fetal heart rate during labor",
            "Contraindications for vaginal delivery in breech presentation",
            "Signs and symptoms of puerperal sepsis",
            "Management of gestational diabetes during pregnancy",
            "Emergency cesarean section indications",
            "Postpartum care guidelines for mother and baby",
            "RhESUS incompatibility management protocol",
            "Antepartum monitoring guidelines for high-risk pregnancy"
        ]

    def load_medical_chunks(self) -> List[Dict]:
        """Load medical chunks from comprehensive chunking results.

        Reads ``<chunks_dir>/langchain_documents_comprehensive.json``, assumed
        to be a list of LangChain-style documents with ``page_content`` and
        ``metadata`` keys, and flattens each into the dict shape used by the
        scoring methods. Chunks with fewer than 100 characters after
        stripping whitespace are skipped.

        Returns:
            List of chunk dicts (content plus selected metadata flags).

        Raises:
            FileNotFoundError: if the documents file does not exist.
        """
        logger.info("Loading medical chunks for embedding evaluation...")

        langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
        if not langchain_file.exists():
            raise FileNotFoundError(f"LangChain documents not found: {langchain_file}")

        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
        # can fail to decode non-ASCII clinical text.
        with open(langchain_file, encoding="utf-8") as f:
            chunks_data = json.load(f)

        medical_chunks = []
        for chunk in chunks_data:
            # Tolerate documents missing content/metadata instead of crashing.
            content = chunk.get('page_content') or ''
            metadata = chunk.get('metadata') or {}

            # Skip very short chunks
            if len(content.strip()) < 100:
                continue

            medical_chunks.append({
                'content': content,
                'chunk_type': metadata.get('chunk_type', 'text'),
                'clinical_importance': metadata.get('clinical_importance', 0.5),
                'source': metadata.get('source', ''),
                'has_dosage_info': metadata.get('has_dosage_info', False),
                'is_maternal_specific': metadata.get('is_maternal_specific', False),
                'has_clinical_protocols': metadata.get('has_clinical_protocols', False)
            })

        logger.info(f"Loaded {len(medical_chunks)} medical chunks for evaluation")
        return medical_chunks

    def evaluate_embedding_model(self, model_name: str, model_path: str) -> Dict[str, Any]:
        """Evaluate a single embedding model.

        Loads ``model_path`` with SentenceTransformer, embeds the first
        ``MAX_SAMPLE_CHUNKS`` entries of ``self.medical_chunks`` plus
        ``self.test_queries``, and scores the resulting embeddings.

        Args:
            model_name: Short display name used in logs and results.
            model_path: Model id/path passed to ``SentenceTransformer``.

        Returns:
            Dict of timings, sizes, search/clustering metrics and
            ``overall_score``. On any exception a dict containing ``error``
            and ``overall_score`` 0.0 is returned instead of raising.
        """
        logger.info(f"Evaluating embedding model: {model_name}")

        try:
            # Load model
            start_time = time.time()
            model = SentenceTransformer(model_path)
            load_time = time.time() - start_time

            # Sample chunks for evaluation (use subset for speed)
            sample_chunks = self.medical_chunks[:self.MAX_SAMPLE_CHUNKS]
            chunk_texts = [chunk['content'] for chunk in sample_chunks]

            # Generate embeddings for chunks
            logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
            start_time = time.time()
            chunk_embeddings = model.encode(chunk_texts, show_progress_bar=True)
            chunk_embed_time = time.time() - start_time

            # Generate embeddings for test queries
            start_time = time.time()
            query_embeddings = model.encode(self.test_queries)
            query_embed_time = time.time() - start_time

            results = {
                'model_name': model_name,
                'model_path': model_path,
                'load_time': load_time,
                'chunk_embed_time': chunk_embed_time,
                'query_embed_time': query_embed_time,
                'embedding_dimension': int(chunk_embeddings.shape[1]),
                'chunks_processed': len(chunk_texts),
                'queries_processed': len(self.test_queries)
            }

            results.update(self._evaluate_search_quality(
                query_embeddings, chunk_embeddings, sample_chunks
            ))
            results.update(self._evaluate_clustering_quality(
                chunk_embeddings, sample_chunks
            ))

            results['overall_score'] = self._calculate_overall_score(results)

            logger.info(f"β {model_name} evaluation complete - Overall Score: {results['overall_score']:.3f}")
            return results

        except Exception as e:
            # Best-effort per model: record the failure (with traceback in the
            # log) so the remaining models are still evaluated.
            logger.exception(f"β Failed to evaluate {model_name}: {e}")
            return {
                'model_name': model_name,
                'model_path': model_path,
                'error': str(e),
                'overall_score': 0.0
            }

    def _evaluate_search_quality(self, query_embeddings: np.ndarray,
                                 chunk_embeddings: np.ndarray,
                                 chunks: List[Dict]) -> Dict[str, float]:
        """Evaluate semantic search quality.

        Row ``i`` of ``query_embeddings`` corresponds to
        ``self.test_queries[i]``. Each query ranks ``chunks`` by cosine
        similarity, and the top ``TOP_K`` hits (fewer if there are fewer
        chunks) are inspected.

        Returns:
            Plain-float metrics:
            - avg_max_similarity: mean best cosine similarity per query.
            - medical_content_precision: mean fraction of top hits with
              ``clinical_importance > 0.7``.
            - dosage_query_accuracy: mean fraction of top hits with
              ``has_dosage_info``, over dosage queries only.
            - emergency_query_accuracy: mean fraction of top hits with
              ``chunk_type == 'emergency'``, over emergency queries only.
            Everything stays 0.0 when there are no chunks or no queries; a
            category accuracy stays 0.0 when no query is in that category.
        """
        search_metrics = {
            'avg_max_similarity': 0.0,
            'medical_content_precision': 0.0,
            'dosage_query_accuracy': 0.0,
            'emergency_query_accuracy': 0.0
        }

        total_queries = len(self.test_queries)
        top_k = min(self.TOP_K, len(chunks))
        if total_queries == 0 or top_k == 0:
            return search_metrics

        similarities = cosine_similarity(query_embeddings, chunk_embeddings)

        dosage_queries = 0
        emergency_queries = 0
        for i, query in enumerate(self.test_queries):
            query_similarities = similarities[i]
            top_indices = np.argsort(query_similarities)[::-1][:top_k]
            top_chunks = [chunks[idx] for idx in top_indices]
            query_lower = query.lower()

            # float(): similarities may be numpy float32 scalars, which NumPy 2
            # keeps as float32 through arithmetic and json.dump cannot serialise.
            search_metrics['avg_max_similarity'] += float(np.max(query_similarities))

            # Divide by the number of hits actually inspected, not a fixed 5.
            medical_relevant = sum(1 for chunk in top_chunks
                                   if chunk['clinical_importance'] > 0.7)
            search_metrics['medical_content_precision'] += medical_relevant / top_k

            if 'dosage' in query_lower or 'dose' in query_lower:
                dosage_queries += 1
                dosage_relevant = sum(1 for chunk in top_chunks
                                      if chunk['has_dosage_info'])
                search_metrics['dosage_query_accuracy'] += dosage_relevant / top_k

            if 'emergency' in query_lower or 'urgent' in query_lower:
                emergency_queries += 1
                emergency_relevant = sum(1 for chunk in top_chunks
                                         if chunk['chunk_type'] == 'emergency')
                search_metrics['emergency_query_accuracy'] += emergency_relevant / top_k

        search_metrics['avg_max_similarity'] /= total_queries
        search_metrics['medical_content_precision'] /= total_queries
        # Category accuracies average over the matching queries only; dividing
        # by all queries would cap them at (matching queries / total queries).
        if dosage_queries:
            search_metrics['dosage_query_accuracy'] /= dosage_queries
        if emergency_queries:
            search_metrics['emergency_query_accuracy'] /= emergency_queries

        return search_metrics

    def _evaluate_clustering_quality(self, embeddings: np.ndarray,
                                     chunks: List[Dict]) -> Dict[str, float]:
        """Evaluate how well embeddings cluster similar medical content.

        Runs KMeans (about one cluster per 10 chunks, between 1 and
        ``MAX_CLUSTERS``) and compares cluster membership with chunk metadata.

        Returns:
            - cluster_purity: size-weighted fraction of each cluster's
              dominant ``chunk_type``.
            - dosage_cluster_coherence / maternal_cluster_coherence: for each
              non-empty cluster, the fraction of ``has_dosage_info`` /
              ``is_maternal_specific`` chunks if it exceeds 0.5 (else 0),
              averaged over non-empty clusters so the value stays in [0, 1].
            All metrics are 0.0 when ``chunks`` is empty.
        """
        cluster_metrics = {
            'cluster_purity': 0.0,
            'dosage_cluster_coherence': 0.0,
            'maternal_cluster_coherence': 0.0
        }

        total_items = len(chunks)
        if total_items == 0:
            return cluster_metrics

        # KMeans rejects n_clusters=0, which len//10 yields for < 10 chunks.
        n_clusters = max(1, min(self.MAX_CLUSTERS, total_items // 10))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(embeddings)

        non_empty_clusters = 0
        for cluster_id in range(n_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]
            if len(cluster_indices) == 0:
                continue
            non_empty_clusters += 1

            cluster_chunks = [chunks[i] for i in cluster_indices]
            size = len(cluster_chunks)

            # Purity: share of the most common chunk type, weighted by cluster size.
            chunk_types = [chunk['chunk_type'] for chunk in cluster_chunks]
            dominant_count = max(chunk_types.count(t) for t in set(chunk_types))
            purity = dominant_count / size
            cluster_metrics['cluster_purity'] += purity * size / total_items

            # Only clusters where such chunks are the majority count as coherent.
            dosage_ratio = sum(1 for chunk in cluster_chunks if chunk['has_dosage_info']) / size
            if dosage_ratio > 0.5:
                cluster_metrics['dosage_cluster_coherence'] += dosage_ratio

            maternal_ratio = sum(1 for chunk in cluster_chunks if chunk['is_maternal_specific']) / size
            if maternal_ratio > 0.5:
                cluster_metrics['maternal_cluster_coherence'] += maternal_ratio

        # Average rather than sum: a plain sum could reach n_clusters and swamp
        # the 0-1 cluster score used in _calculate_overall_score.
        if non_empty_clusters:
            cluster_metrics['dosage_cluster_coherence'] /= non_empty_clusters
            cluster_metrics['maternal_cluster_coherence'] /= non_empty_clusters

        return cluster_metrics

    def _calculate_overall_score(self, results: Dict[str, Any]) -> float:
        """Calculate overall score for the embedding model.

        Weighted blend: search quality 0.4, clustering quality 0.2, speed 0.2,
        medical relevance 0.2. The speed score falls linearly from 1 at 0 s to
        0 at 100 s of total chunk + query embedding time.

        Returns:
            Plain Python float clamped to [0, 1]; 0.0 if ``results`` has an
            ``error`` key.
        """
        if 'error' in results:
            return 0.0

        weights = {
            'search_quality': 0.4,
            'clustering_quality': 0.2,
            'speed': 0.2,
            'medical_relevance': 0.2
        }

        search_score = (
            results.get('avg_max_similarity', 0) * 0.4 +
            results.get('medical_content_precision', 0) * 0.3 +
            results.get('dosage_query_accuracy', 0) * 0.15 +
            results.get('emergency_query_accuracy', 0) * 0.15
        )

        cluster_score = (
            results.get('cluster_purity', 0) * 0.5 +
            results.get('dosage_cluster_coherence', 0) * 0.25 +
            results.get('maternal_cluster_coherence', 0) * 0.25
        )

        total_time = results.get('chunk_embed_time', 1) + results.get('query_embed_time', 1)
        speed_score = max(0, 1 - (total_time / 100))

        medical_score = (
            results.get('medical_content_precision', 0) * 0.6 +
            results.get('dosage_query_accuracy', 0) * 0.4
        )

        overall = (
            search_score * weights['search_quality'] +
            cluster_score * weights['clustering_quality'] +
            speed_score * weights['speed'] +
            medical_score * weights['medical_relevance']
        )

        # float(): inputs may be numpy scalars; keep the result JSON-serialisable.
        return float(min(1.0, max(0.0, overall)))

    def run_comprehensive_evaluation(self) -> Dict[str, Any]:
        """Run comprehensive evaluation of all embedding models.

        Loads the chunks, evaluates every entry in ``self.embedding_models``,
        stores the per-model results in ``self.evaluation_results``, writes
        summary and details to ``src/embedding_evaluation_results.json``
        (relative to the working directory) and returns the summary.

        Raises:
            FileNotFoundError: propagated from ``load_medical_chunks``.
            ValueError: if no chunks survive loading and filtering.
        """
        logger.info("Starting comprehensive embedding model evaluation...")

        self.medical_chunks = self.load_medical_chunks()
        if len(self.medical_chunks) == 0:
            raise ValueError("No medical chunks loaded for evaluation")

        results = {}
        for model_name, model_path in self.embedding_models.items():
            logger.info(f"\nπ Evaluating: {model_name}")
            results[model_name] = self.evaluate_embedding_model(model_name, model_path)
        self.evaluation_results = results

        summary = self._generate_evaluation_summary(results)

        output_file = Path("src/embedding_evaluation_results.json")
        # Create the directory up front; otherwise open() fails only after
        # all the slow model evaluations have finished.
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'evaluation_summary': summary,
                'detailed_results': results,
                'test_queries': self.test_queries,
                'chunks_evaluated': len(self.medical_chunks)
            }, f, indent=2)

        logger.info(f"π Evaluation results saved to: {output_file}")
        return summary

    def _generate_evaluation_summary(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate evaluation summary with recommendations.

        Args:
            results: Per-model result dicts keyed by model name. Entries with
                an ``error`` key are excluded from ranking and averages.

        Returns:
            Summary dict (best model, rankings, averaged metrics,
            recommendation, counts), or
            ``{'error': 'No models evaluated successfully'}`` if every model
            failed.
        """
        valid_results = {k: v for k, v in results.items() if 'error' not in v}
        if not valid_results:
            return {'error': 'No models evaluated successfully'}

        best_name, best_result = max(valid_results.items(),
                                     key=lambda item: item[1]['overall_score'])

        avg_scores = {}
        for metric in ['overall_score', 'avg_max_similarity', 'medical_content_precision']:
            scores = [r.get(metric, 0) for r in valid_results.values()]
            avg_scores[f'avg_{metric}'] = sum(scores) / len(scores)

        strengths = []
        if best_result.get('medical_content_precision', 0) > 0.7:
            strengths.append("High medical content precision")
        if best_result.get('dosage_query_accuracy', 0) > 0.6:
            strengths.append("Good dosage information retrieval")
        if best_result.get('cluster_purity', 0) > 0.6:
            strengths.append("Effective content clustering")
        if best_result.get('chunk_embed_time', 100) < 30:
            strengths.append("Fast embedding generation")

        if strengths:
            detail = (f"This model shows {', '.join(strengths)} "
                      f"and is well-suited for maternal health content.")
        else:
            # Avoid the garbled "This model shows  and ..." sentence.
            detail = ("It ranked highest overall, but none of its metrics "
                      "crossed the strength thresholds.")

        return {
            'best_model': {
                'name': best_name,
                'path': best_result['model_path'],
                'score': best_result['overall_score'],
                'strengths': strengths
            },
            'model_rankings': sorted(
                [(name, res['overall_score']) for name, res in valid_results.items()],
                key=lambda x: x[1], reverse=True
            ),
            'evaluation_metrics': avg_scores,
            'recommendation': (
                f"Recommended model: {best_name} with overall score "
                f"{best_result['overall_score']:.3f}. " + detail
            ),
            'models_tested': len(results),
            'successful_evaluations': len(valid_results)
        }
def main():
    """Run the embedding-model evaluation and log a readable summary.

    Returns:
        The summary dict on success; ``None`` if the evaluation raised or if
        no model could be evaluated successfully.
    """
    evaluator = MedicalEmbeddingEvaluator()

    try:
        summary = evaluator.run_comprehensive_evaluation()

        # When every model fails the summary is {'error': ...} with no
        # 'best_model' key; report that directly instead of via a KeyError.
        if 'error' in summary:
            logger.error(f"β Evaluation failed: {summary['error']}")
            return None

        logger.info("=" * 80)
        logger.info("EMBEDDING MODEL EVALUATION COMPLETE!")
        logger.info("=" * 80)
        logger.info(f"π Best Model: {summary['best_model']['name']}")
        logger.info(f"π Overall Score: {summary['best_model']['score']:.3f}")
        logger.info(f"πͺ Strengths: {', '.join(summary['best_model']['strengths'])}")
        logger.info(f"π Recommendation: {summary['recommendation']}")

        logger.info("\nπ Model Rankings:")
        for i, (model, score) in enumerate(summary['model_rankings'], 1):
            logger.info(f"{i}. {model}: {score:.3f}")

        logger.info("=" * 80)
        return summary

    except Exception as e:
        # Top-level boundary: keep the traceback in the log and signal failure
        # to the caller with None.
        logger.exception(f"β Evaluation failed: {e}")
        return None
if __name__ == "__main__":
    # main() returns None on failure; turn that into a non-zero exit status so
    # shell scripts / CI can detect a failed evaluation.
    raise SystemExit(0 if main() is not None else 1)