Spaces:
Running
Running
| # chat/comparator.py | |
| """ | |
| Advanced cross-paper comparison and methodology analysis | |
| Compares techniques, results, and approaches across multiple studies | |
| """ | |
| from typing import List, Dict, Any, Tuple | |
| from llm.llm_provider import GrokLLM | |
| from llm.prompt_templates import MedicalResearchPrompts, ResponseFormatter | |
| import re | |
class CrossPaperComparator:
    """
    Advanced comparator that analyzes and compares research across multiple papers.

    Combines keyword/alias matching over titles and abstracts, regex-based
    extraction of numeric performance figures, publication-date trend counting,
    and an LLM-generated narrative comparison.
    """

    # Known aliases for common method names; keys are lower-case.
    _METHOD_ALIASES = {
        'cnn': ['cnn', 'convolutional neural network', 'convolutional network'],
        'transformer': ['transformer', 'attention mechanism', 'self-attention'],
        'random forest': ['random forest', 'rf classifier'],
        'svm': ['svm', 'support vector machine'],
        'knn': ['knn', 'k-nearest neighbor'],
        'logistic regression': ['logistic regression', 'logit model'],
    }

    # Patterns for pulling numeric performance figures out of abstracts.
    # Compiled once at class level instead of being rebuilt on every call.
    _PERFORMANCE_PATTERNS = [
        re.compile(pattern) for pattern in (
            r'(\d+\.?\d*)%',               # bare percentage values
            r'accuracy of (\d+\.?\d*)',
            r'precision of (\d+\.?\d*)',
            r'recall of (\d+\.?\d*)',
            r'f1 score of (\d+\.?\d*)',
            r'auc of (\d+\.?\d*)',
        )
    ]

    def __init__(self, llm=None):
        """
        Args:
            llm: optional shared LLM client; when omitted a GrokLLM is created.
        """
        # NOTE(review): "model" looks like a placeholder model id — confirm.
        self.llm = llm or GrokLLM(model="model")  # Use shared LLM
        self.prompts = MedicalResearchPrompts()
        self.formatter = ResponseFormatter()

    def compare_methods(self, papers: List[Dict], method1: str, method2: str, domain: str) -> Dict[str, Any]:
        """
        Compare two methods across multiple research papers.

        Returns a dict with the LLM narrative, extracted performance metrics,
        trend analysis and formatted citations. Falls back to a non-LLM summary
        when generation fails, and to an "insufficient data" response when
        fewer than two papers mention either method.
        """
        print(f"π Comparing {method1} vs {method2} across {len(papers)} papers")

        # Only papers that mention at least one of the two methods are useful.
        relevant_papers = self._filter_relevant_papers(papers, [method1, method2])
        if len(relevant_papers) < 2:
            return self._create_insufficient_data_response(method1, method2, domain, len(relevant_papers))

        try:
            # Generate detailed comparison via the LLM.
            comparison_prompt = self.prompts.cross_paper_comparison(relevant_papers, method1, method2, domain)
            response = self.llm.generate(
                comparison_prompt,
                system_message=self.prompts.SYSTEM_MESSAGES["methodology_expert"],
                temperature=0.1,
                max_tokens=2000
            )
            # Structured data is extracted independently of the LLM output.
            performance_data = self._extract_performance_metrics(relevant_papers, method1, method2)
            trend_analysis = self._analyze_trends(relevant_papers, method1, method2)
            return {
                "detailed_comparison": response,
                "performance_metrics": performance_data,
                "trend_analysis": trend_analysis,
                "method1": method1,
                "method2": method2,
                "domain": domain,
                "papers_analyzed": len(relevant_papers),
                "relevant_papers": [
                    self.formatter.format_citation(paper, i + 1)
                    for i, paper in enumerate(relevant_papers)
                ],
            }
        except Exception as e:
            print(f"β Comparison error: {e}")
            return self._create_fallback_comparison(relevant_papers, method1, method2, domain)

    def _filter_relevant_papers(self, papers: List[Dict], methods: List[str]) -> List[Dict]:
        """Return the papers whose title or abstract mentions any of *methods*."""
        relevant_papers = []
        for paper in papers:
            abstract = paper.get('abstract', '').lower()
            title = paper.get('title', '').lower()
            for method in methods:
                method_terms = self._get_method_variations(method)
                if any(term in abstract or term in title for term in method_terms):
                    relevant_papers.append(paper)
                    break  # one match is enough; avoid appending duplicates
        return relevant_papers

    def _get_method_variations(self, method: str) -> List[str]:
        """Return lower-cased name variations for *method*, the name itself first."""
        method = method.lower()
        variations = [method] + self._METHOD_ALIASES.get(method, [])
        # De-duplicate while preserving order (alias lists repeat the base name).
        return list(dict.fromkeys(variations))

    def _extract_performance_metrics(self, papers: List[Dict], method1: str, method2: str) -> Dict[str, Any]:
        """
        Extract numeric performance figures and mention counts from abstracts.

        Figures found in a paper are attributed to every compared method that
        the paper mentions — they are keyword associations, not verified
        per-method results.
        """
        metrics = {
            method1: {"papers_count": 0, "performance_mentions": []},
            method2: {"papers_count": 0, "performance_mentions": []},
            "comparison_mentions": []
        }
        for paper in papers:
            abstract = paper.get('abstract', '').lower()
            title = paper.get('title', '').lower()

            method1_mentioned = any(term in abstract or term in title
                                    for term in self._get_method_variations(method1))
            method2_mentioned = any(term in abstract or term in title
                                    for term in self._get_method_variations(method2))
            if method1_mentioned:
                metrics[method1]["papers_count"] += 1
            if method2_mentioned:
                metrics[method2]["papers_count"] += 1

            # Collect every numeric figure the patterns find in the abstract.
            for pattern in self._PERFORMANCE_PATTERNS:
                for match in pattern.findall(abstract):
                    try:
                        value = float(match)
                    except ValueError:
                        continue
                    if method1_mentioned:
                        metrics[method1]["performance_mentions"].append(value)
                    if method2_mentioned:
                        metrics[method2]["performance_mentions"].append(value)

            # Track papers that explicitly compare methods.
            comparison_terms = ['compared to', 'versus', 'vs', 'outperform', 'better than', 'worse than']
            if any(term in abstract for term in comparison_terms) and (method1_mentioned or method2_mentioned):
                metrics["comparison_mentions"].append(paper.get('title', 'Unknown'))

        # Summary statistics, only when at least one figure was found.
        for method in (method1, method2):
            mentions = metrics[method]["performance_mentions"]
            if mentions:
                metrics[method]["average_performance"] = sum(mentions) / len(mentions)
                metrics[method]["max_performance"] = max(mentions)
                metrics[method]["min_performance"] = min(mentions)
        return metrics

    def _analyze_trends(self, papers: List[Dict], method1: str, method2: str,
                        recent_years: Tuple[str, ...] = ('2023', '2024')) -> Dict[str, Any]:
        """
        Count total and recent papers per method and derive a trend direction.

        Args:
            recent_years: substrings of ``publication_date`` that count as
                "recent"; the default preserves the original 2023/2024 window.
        """
        trends = {
            method1: {"recent_papers": 0, "total_papers": 0},
            method2: {"recent_papers": 0, "total_papers": 0},
            "trend_direction": "neutral"
        }
        for paper in papers:
            # Simple date-substring recency check (no date parsing).
            date = paper.get('publication_date', '')
            abstract = paper.get('abstract', '').lower()
            is_recent = any(year in date for year in recent_years)
            for method in (method1, method2):
                if any(term in abstract for term in self._get_method_variations(method)):
                    trends[method]["total_papers"] += 1
                    if is_recent:
                        trends[method]["recent_papers"] += 1

        # A method is "gaining" when its recent share beats the other's by > 0.2.
        method1_recent_ratio = trends[method1]["recent_papers"] / max(1, trends[method1]["total_papers"])
        method2_recent_ratio = trends[method2]["recent_papers"] / max(1, trends[method2]["total_papers"])
        if method1_recent_ratio > method2_recent_ratio + 0.2:
            trends["trend_direction"] = f"{method1} gaining"
        elif method2_recent_ratio > method1_recent_ratio + 0.2:
            trends["trend_direction"] = f"{method2} gaining"
        else:
            trends["trend_direction"] = "both stable"
        return trends

    def _create_insufficient_data_response(self, method1: str, method2: str, domain: str,
                                           relevant_count: int) -> Dict[str, Any]:
        """Build the response used when too few papers mention either method."""
        return {
            "detailed_comparison": f"Insufficient data for comparison. Only {relevant_count} papers mention {method1} or {method2} in the {domain} domain.",
            "performance_metrics": {},
            "trend_analysis": {},
            "method1": method1,
            "method2": method2,
            "domain": domain,
            "papers_analyzed": relevant_count,
            "relevant_papers": [],
            "insufficient_data": True
        }

    def _create_fallback_comparison(self, papers: List[Dict], method1: str, method2: str,
                                    domain: str) -> Dict[str, Any]:
        """Build a non-LLM comparison from the structured data when generation fails."""
        performance_data = self._extract_performance_metrics(papers, method1, method2)
        trend_analysis = self._analyze_trends(papers, method1, method2)
        basic_comparison = f"""
Basic Comparison: {method1} vs {method2} in {domain}
Papers Analyzed: {len(papers)}
{method1} mentioned in: {performance_data[method1]['papers_count']} papers
{method2} mentioned in: {performance_data[method2]['papers_count']} papers
Trend: {trend_analysis['trend_direction']}
Note: Detailed AI comparison unavailable. Consider refining your search terms.
"""
        return {
            "detailed_comparison": basic_comparison,
            "performance_metrics": performance_data,
            "trend_analysis": trend_analysis,
            "method1": method1,
            "method2": method2,
            "domain": domain,
            "papers_analyzed": len(papers),
            "relevant_papers": [self.formatter.format_citation(paper, i + 1) for i, paper in enumerate(papers)],
            "fallback_used": True
        }

    def generate_comparison_table(self, comparison_data: Dict[str, Any]) -> str:
        """
        Render a markdown comparison table from ``compare_methods`` output.

        Degrades to a short notice when the structured metrics are missing
        (e.g. an insufficient-data response carries empty dicts) instead of
        raising KeyError.
        """
        method1 = comparison_data['method1']
        method2 = comparison_data['method2']
        metrics = comparison_data.get('performance_metrics') or {}
        trends = comparison_data.get('trend_analysis') or {}

        table = f"**Comparison: {method1} vs {method2}**\n\n"
        if (method1 not in metrics or method2 not in metrics
                or method1 not in trends or method2 not in trends):
            # Insufficient-data / malformed responses carry empty structures.
            return table + "No structured metrics available for a comparison table.\n"

        table += f"| Metric | {method1} | {method2} |\n"
        table += "|--------|-----------|-----------|\n"
        # Papers count
        table += f"| Papers Mentioned | {metrics[method1]['papers_count']} | {metrics[method2]['papers_count']} |\n"
        # Recent papers
        table += f"| Recent Papers (2023-2024) | {trends[method1]['recent_papers']} | {trends[method2]['recent_papers']} |\n"
        # Average performance if available
        if 'average_performance' in metrics[method1] and 'average_performance' in metrics[method2]:
            table += (f"| Avg Performance | {metrics[method1]['average_performance']:.1f}% "
                      f"| {metrics[method2]['average_performance']:.1f}% |\n")
        table += f"\n**Trend:** {trends['trend_direction']}\n"
        return table
# Quick test
def test_comparator():
    """Exercise CrossPaperComparator end-to-end on two sample papers."""
    print("π§ͺ Testing Cross-Paper Comparator")
    print("=" * 50)

    sample_papers = [
        {
            'title': 'CNN vs Transformer for Medical Images',
            'authors': ['Smith J', 'Lee K'],
            'abstract': 'We compare CNN and Transformer architectures for medical image classification. CNNs achieve 92% accuracy while Transformers reach 94% but require more data.',
            'source': 'Medical Image Analysis',
            'domain': 'medical_imaging',
            'publication_date': '2024-01-15',
        },
        {
            'title': 'Efficient Transformers in Radiology',
            'authors': ['Chen R', 'Wang L'],
            'abstract': 'This paper introduces efficient transformer variants for radiology applications. Our method maintains 93% accuracy with 50% fewer parameters compared to standard transformers.',
            'source': 'IEEE Transactions',
            'domain': 'medical_imaging',
            'publication_date': '2024-03-10',
        },
    ]

    comparator = CrossPaperComparator()
    try:
        result = comparator.compare_methods(
            sample_papers, "CNN", "Transformer", "medical_imaging"
        )
        print(f"β Comparison generated successfully")
        print(f"π Papers analyzed: {result['papers_analyzed']}")
        print(f"π Trend: {result['trend_analysis']['trend_direction']}")
        print(f"\nπ Comparison Table:\n{comparator.generate_comparison_table(result)}")
    except Exception as exc:
        print(f"β Comparison test failed: {exc}")


if __name__ == "__main__":
    test_comparator()