# chat/comparator.py
"""
Advanced cross-paper comparison and methodology analysis
Compares techniques, results, and approaches across multiple studies
"""
import re
from typing import Any, Dict, List

from llm.llm_provider import GrokLLM
from llm.prompt_templates import MedicalResearchPrompts, ResponseFormatter

class CrossPaperComparator:
"""
Advanced comparator that analyzes and compares research across multiple papers
Focuses on methodologies, results, and performance metrics
"""
def __init__(self, llm=None):
        self.llm = llm or GrokLLM(model="model")  # Reuse an injected shared LLM if provided, else fall back to a default client
self.prompts = MedicalResearchPrompts()
self.formatter = ResponseFormatter()
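
    # Illustrative wiring (a sketch; assumes GrokLLM exposes the generate(...)
    # call used below, and that model="model" is swapped for a real model id):
    #
    #     shared_llm = GrokLLM(model="model")
    #     comparator = CrossPaperComparator(llm=shared_llm)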
def compare_methods(self, papers: List[Dict], method1: str, method2: str, domain: str) -> Dict[str, Any]:
"""
Compare two methods across multiple research papers
"""
print(f"πŸ” Comparing {method1} vs {method2} across {len(papers)} papers")
# Filter papers that mention either method
relevant_papers = self._filter_relevant_papers(papers, [method1, method2])
if len(relevant_papers) < 2:
return self._create_insufficient_data_response(method1, method2, domain, len(relevant_papers))
try:
# Generate detailed comparison
comparison_prompt = self.prompts.cross_paper_comparison(relevant_papers, method1, method2, domain)
response = self.llm.generate(
comparison_prompt,
system_message=self.prompts.SYSTEM_MESSAGES["methodology_expert"],
temperature=0.1,
max_tokens=2000
)
# Extract performance metrics and structured data
performance_data = self._extract_performance_metrics(relevant_papers, method1, method2)
trend_analysis = self._analyze_trends(relevant_papers, method1, method2)
return {
"detailed_comparison": response,
"performance_metrics": performance_data,
"trend_analysis": trend_analysis,
"method1": method1,
"method2": method2,
"domain": domain,
"papers_analyzed": len(relevant_papers),
"relevant_papers": [self.formatter.format_citation(paper, i + 1) for i, paper in
enumerate(relevant_papers)]
}
except Exception as e:
print(f"❌ Comparison error: {e}")
return self._create_fallback_comparison(relevant_papers, method1, method2, domain)
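
    # Illustrative shape of a successful compare_methods() result (keys only;
    # values depend on the corpus): detailed_comparison (str),
    # performance_metrics (dict), trend_analysis (dict), method1/method2/domain
    # (str), papers_analyzed (int), relevant_papers (list of formatted citations).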
def _filter_relevant_papers(self, papers: List[Dict], methods: List[str]) -> List[Dict]:
"""Filter papers that mention the methods of interest"""
relevant_papers = []
for paper in papers:
abstract = paper.get('abstract', '').lower()
title = paper.get('title', '').lower()
# Check if paper mentions any of the methods
for method in methods:
method_terms = self._get_method_variations(method)
if any(term in abstract or term in title for term in method_terms):
relevant_papers.append(paper)
break
return relevant_papers
def _get_method_variations(self, method: str) -> List[str]:
"""Get common variations of method names"""
method = method.lower()
variations = [method]
# Common method variations
method_variations = {
'cnn': ['cnn', 'convolutional neural network', 'convolutional network'],
'transformer': ['transformer', 'attention mechanism', 'self-attention'],
'random forest': ['random forest', 'rf classifier'],
'svm': ['svm', 'support vector machine'],
'knn': ['knn', 'k-nearest neighbor'],
'logistic regression': ['logistic regression', 'logit model']
}
        if method in method_variations:
            # Extend without duplicating the base name already seeded in `variations`
            variations.extend(term for term in method_variations[method] if term not in variations)
return variations
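
    # Illustrative output: _get_method_variations("cnn") returns
    # ['cnn', 'convolutional neural network', 'convolutional network'];
    # unknown methods fall through to a single-item list, e.g. ['resnet'].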
def _extract_performance_metrics(self, papers: List[Dict], method1: str, method2: str) -> Dict[str, Any]:
"""Extract performance metrics from papers"""
metrics = {
method1: {"papers_count": 0, "performance_mentions": []},
method2: {"papers_count": 0, "performance_mentions": []},
"comparison_mentions": []
}
performance_patterns = [
r'(\d+\.?\d*)%', # Percentage metrics
r'accuracy of (\d+\.?\d*)',
r'precision of (\d+\.?\d*)',
r'recall of (\d+\.?\d*)',
r'f1 score of (\d+\.?\d*)',
r'auc of (\d+\.?\d*)'
]
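
        # Illustrative matches (made-up snippets): "accuracy of 92.5" yields
        # ['92.5'] via the accuracy pattern, while the bare percentage pattern
        # also captures numbers unrelated to model quality (e.g. "12%" in
        # "cases rose by 12%"), so the aggregates below are heuristic.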
for paper in papers:
abstract = paper.get('abstract', '').lower()
title = paper.get('title', '').lower()
# Count method mentions
method1_terms = self._get_method_variations(method1)
method2_terms = self._get_method_variations(method2)
method1_mentioned = any(term in abstract or term in title for term in method1_terms)
method2_mentioned = any(term in abstract or term in title for term in method2_terms)
if method1_mentioned:
metrics[method1]["papers_count"] += 1
if method2_mentioned:
metrics[method2]["papers_count"] += 1
# Extract performance numbers
for pattern in performance_patterns:
matches = re.findall(pattern, abstract)
for match in matches:
try:
value = float(match)
if method1_mentioned:
metrics[method1]["performance_mentions"].append(value)
if method2_mentioned:
metrics[method2]["performance_mentions"].append(value)
except ValueError:
continue
# Look for direct comparisons
comparison_terms = ['compared to', 'versus', 'vs', 'outperform', 'better than', 'worse than']
if any(term in abstract for term in comparison_terms) and (method1_mentioned or method2_mentioned):
metrics["comparison_mentions"].append(paper.get('title', 'Unknown'))
# Calculate average performance if we have data
for method in [method1, method2]:
mentions = metrics[method]["performance_mentions"]
if mentions:
metrics[method]["average_performance"] = sum(mentions) / len(mentions)
metrics[method]["max_performance"] = max(mentions)
metrics[method]["min_performance"] = min(mentions)
return metrics
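
    # Shape of the returned metrics dict (illustrative numbers, not real data):
    #   {"cnn": {"papers_count": 2, "performance_mentions": [92.0, 94.0],
    #            "average_performance": 93.0, "max_performance": 94.0,
    #            "min_performance": 92.0},
    #    "transformer": {...},
    #    "comparison_mentions": ["CNN vs Transformer for Medical Images"]}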
def _analyze_trends(self, papers: List[Dict], method1: str, method2: str) -> Dict[str, Any]:
"""Analyze publication trends for each method"""
trends = {
method1: {"recent_papers": 0, "total_papers": 0},
method2: {"recent_papers": 0, "total_papers": 0},
"trend_direction": "neutral"
}
for paper in papers:
            # Simple date-based trend analysis; "recent" is hard-coded to 2023-2024
date = paper.get('publication_date', '')
abstract = paper.get('abstract', '').lower()
method1_mentioned = any(term in abstract for term in self._get_method_variations(method1))
method2_mentioned = any(term in abstract for term in self._get_method_variations(method2))
if method1_mentioned:
trends[method1]["total_papers"] += 1
if '2024' in date or '2023' in date:
trends[method1]["recent_papers"] += 1
if method2_mentioned:
trends[method2]["total_papers"] += 1
if '2024' in date or '2023' in date:
trends[method2]["recent_papers"] += 1
# Determine trend direction
method1_recent_ratio = trends[method1]["recent_papers"] / max(1, trends[method1]["total_papers"])
method2_recent_ratio = trends[method2]["recent_papers"] / max(1, trends[method2]["total_papers"])
if method1_recent_ratio > method2_recent_ratio + 0.2:
trends["trend_direction"] = f"{method1} gaining"
elif method2_recent_ratio > method1_recent_ratio + 0.2:
trends["trend_direction"] = f"{method2} gaining"
else:
trends["trend_direction"] = "both stable"
return trends
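
    # Worked example (hypothetical counts): if method1 has 4/5 recent papers
    # (ratio 0.8) and method2 has 2/5 (ratio 0.4), the 0.2 margin makes the
    # direction "method1 gaining"; ratios within 0.2 of each other read as
    # "both stable".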
    def _create_insufficient_data_response(self, method1: str, method2: str, domain: str,
                                           relevant_count: int) -> Dict[str, Any]:
"""Create response when insufficient data is available"""
return {
"detailed_comparison": f"Insufficient data for comparison. Only {relevant_count} papers mention {method1} or {method2} in the {domain} domain.",
"performance_metrics": {},
"trend_analysis": {},
"method1": method1,
"method2": method2,
"domain": domain,
"papers_analyzed": relevant_count,
"relevant_papers": [],
"insufficient_data": True
}
    def _create_fallback_comparison(self, papers: List[Dict], method1: str, method2: str,
                                    domain: str) -> Dict[str, Any]:
"""Create basic comparison when LLM fails"""
performance_data = self._extract_performance_metrics(papers, method1, method2)
trend_analysis = self._analyze_trends(papers, method1, method2)
basic_comparison = f"""
Basic Comparison: {method1} vs {method2} in {domain}
Papers Analyzed: {len(papers)}
{method1} mentioned in: {performance_data[method1]['papers_count']} papers
{method2} mentioned in: {performance_data[method2]['papers_count']} papers
Trend: {trend_analysis['trend_direction']}
Note: Detailed AI comparison unavailable. Consider refining your search terms.
"""
return {
"detailed_comparison": basic_comparison,
"performance_metrics": performance_data,
"trend_analysis": trend_analysis,
"method1": method1,
"method2": method2,
"domain": domain,
"papers_analyzed": len(papers),
"relevant_papers": [self.formatter.format_citation(paper, i + 1) for i, paper in enumerate(papers)],
"fallback_used": True
}
def generate_comparison_table(self, comparison_data: Dict[str, Any]) -> str:
"""Generate a structured comparison table"""
method1 = comparison_data['method1']
method2 = comparison_data['method2']
metrics = comparison_data['performance_metrics']
trends = comparison_data['trend_analysis']
table = f"**Comparison: {method1} vs {method2}**\n\n"
table += "| Metric | {method1} | {method2} |\n".format(method1=method1, method2=method2)
table += "|--------|-----------|-----------|\n"
        # Papers count
        table += (
            f"| Papers Mentioned | {metrics[method1]['papers_count']} "
            f"| {metrics[method2]['papers_count']} |\n"
        )
        # Recent papers (same window as _analyze_trends)
        table += (
            f"| Recent Papers (2023-2024) | {trends[method1]['recent_papers']} "
            f"| {trends[method2]['recent_papers']} |\n"
        )
        # Average performance if available (caveat: averages mix heterogeneous
        # metrics, e.g. accuracy percentages and AUC values, so treat as rough)
        if 'average_performance' in metrics[method1] and 'average_performance' in metrics[method2]:
            table += (
                f"| Avg Performance | {metrics[method1]['average_performance']:.1f}% "
                f"| {metrics[method2]['average_performance']:.1f}% |\n"
            )
table += f"\n**Trend:** {trends['trend_direction']}\n"
return table
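
    # Example rendered output (illustrative numbers):
    #
    #   **Comparison: CNN vs Transformer**
    #
    #   | Metric | CNN | Transformer |
    #   |--------|-----------|-----------|
    #   | Papers Mentioned | 2 | 2 |
    #   | Recent Papers (2023-2024) | 2 | 2 |
    #
    #   **Trend:** both stable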
# Quick test
def test_comparator():
"""Test the cross-paper comparator"""
print("πŸ§ͺ Testing Cross-Paper Comparator")
print("=" * 50)
test_papers = [
{
'title': 'CNN vs Transformer for Medical Images',
'authors': ['Smith J', 'Lee K'],
'abstract': 'We compare CNN and Transformer architectures for medical image classification. CNNs achieve 92% accuracy while Transformers reach 94% but require more data.',
'source': 'Medical Image Analysis',
'domain': 'medical_imaging',
'publication_date': '2024-01-15'
},
{
'title': 'Efficient Transformers in Radiology',
'authors': ['Chen R', 'Wang L'],
'abstract': 'This paper introduces efficient transformer variants for radiology applications. Our method maintains 93% accuracy with 50% fewer parameters compared to standard transformers.',
'source': 'IEEE Transactions',
'domain': 'medical_imaging',
'publication_date': '2024-03-10'
}
]
comparator = CrossPaperComparator()
try:
comparison = comparator.compare_methods(
test_papers,
"CNN",
"Transformer",
"medical_imaging"
)
print(f"βœ… Comparison generated successfully")
print(f"πŸ“Š Papers analyzed: {comparison['papers_analyzed']}")
print(f"πŸ“ˆ Trend: {comparison['trend_analysis']['trend_direction']}")
table = comparator.generate_comparison_table(comparison)
print(f"\nπŸ“‹ Comparison Table:\n{table}")
except Exception as e:
print(f"❌ Comparison test failed: {e}")
if __name__ == "__main__":
test_comparator()