# Medica_DecisionSupportAI / schema_mapper.py
# Maps abstract analysis concepts (facility, wait times, costs, ...) onto
# concrete dataset columns based on scenario text and column-name patterns.
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Dict, List, Any, Tuple, Optional, Set
import pandas as pd
from data_registry import DataRegistry
# Generic concept patterns that work across domains.
# Maps an abstract concept name (e.g. "facility", "wait_median") to an ordered
# list of regex alternatives. Order matters: _score_column_match awards
# 100 - 10 * index for the first pattern that hits, so earlier patterns are
# treated as stronger evidence for the concept.
UNIVERSAL_CONCEPT_PATTERNS = {
    # Entity/grouping concepts
    "facility": [r"\bfacilit(y|ies)\b", r"\bhospital\b", r"\bsite\b", r"\bcentre\b", r"\bcenter\b", r"\blocation\b", r"\bprovider\b"],
    "organization": [r"\borganization\b", r"\bcompany\b", r"\bbusiness\b", r"\bfirm\b", r"\bentity\b"],
    "department": [r"\bdepartment\b", r"\bdivision\b", r"\bunit\b", r"\bsection\b"],
    "specialty": [r"\bspecialt(y|ies)\b", r"\bservice\b", r"\btype\b", r"\bcategory\b", r"\bkind\b"],
    "region": [r"\bzone\b", r"\bregion\b", r"\barea\b", r"\bdistrict\b", r"\bterritory\b"],
    # Time-based metrics (prefix-only patterns deliberately match compounds
    # such as "wait_days" or "delay_minutes")
    "wait_time": [r"\bwait", r"\bdelay", r"\btime", r"\bduration", r"\blength"],
    "wait_median": [r"\bmedian\b.*\bwait", r"\bP50\b", r"\bwait.*\bmedian", r"median.*time"],
    "wait_p90": [r"\bp90\b", r"\b90(th)?\s*percentile\b", r"\bwait.*p90", r"90.*wait"],
    "response_time": [r"\bresponse\b.*\btime\b", r"\bprocessing\b.*\btime\b"],
    # Performance metrics
    "score": [r"\bscore\b", r"\brating\b", r"\bindex\b", r"\brank\b"],
    "efficiency": [r"\befficiency\b", r"\bthroughput\b", r"\bproductivity\b"],
    "quality": [r"\bquality\b", r"\bperformance\b", r"\boutcome\b"],
    "satisfaction": [r"\bsatisfaction\b", r"\bfeedback\b", r"\brating\b"],
    # Capacity metrics
    "capacity": [r"\bcapacity\b", r"\bvolume\b", r"\bsize\b", r"\blimit\b"],
    "utilization": [r"\butilization\b", r"\boccupancy\b", r"\busage\b"],
    "availability": [r"\bavailab\w+", r"\bopen\b", r"\bfree\b"],
    # Cost/financial metrics
    "cost": [r"\bcost\b", r"\bprice\b", r"\bexpense\b", r"\bfee\b", r"\bcharge\b"],
    "budget": [r"\bbudget\b", r"\bfunding\b", r"\ballocation\b"],
    "revenue": [r"\brevenue\b", r"\bincome\b", r"\bearnings\b"],
    # Count/volume metrics
    "count": [r"\bcount\b", r"\bnumber\b", r"\bquantity\b", r"\btotal\b"],
    "rate": [r"\brate\b", r"\bratio\b", r"\bpercent\b", r"\bfrequency\b"],
    "volume": [r"\bvolume\b", r"\bamount\b", r"\bquantity\b"]
}
def _extract_key_terms_from_scenario(scenario_text: str) -> Set[str]:
"""Extract important terms from scenario text to guide concept detection."""
if not scenario_text:
return set()
# Extract meaningful words, filtering out common stop words
stop_words = {
'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'do', 'does', 'did',
'a', 'an', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they'
}
words = re.findall(r'\b[a-zA-Z]{3,}\b', scenario_text.lower())
key_terms = {word for word in words if word not in stop_words}
return key_terms
def _generate_dynamic_patterns(scenario_terms: Set[str], existing_patterns: Dict[str, List[str]]) -> Dict[str, List[str]]:
"""Generate additional concept patterns based on scenario content."""
dynamic_patterns = existing_patterns.copy()
# Add scenario-specific terms as potential concepts
for term in scenario_terms:
if len(term) >= 4: # Only meaningful terms
# Check if term relates to existing concepts
term_pattern = rf"\b{re.escape(term)}\b"
# Add as potential entity if it sounds like one
if any(indicator in term for indicator in ['hospital', 'clinic', 'school', 'department', 'facility']):
if 'facility' not in dynamic_patterns:
dynamic_patterns['facility'] = []
dynamic_patterns['facility'].append(term_pattern)
# Add as potential metric if it sounds like one
elif any(indicator in term for indicator in ['time', 'score', 'rate', 'cost', 'wait']):
concept_key = f"metric_{term}"
dynamic_patterns[concept_key] = [term_pattern]
return dynamic_patterns
def _score_column_match(col_name: str, patterns: List[str], scenario_terms: Set[str] = None) -> int:
"""Score how well a column matches concept patterns."""
col_lower = col_name.lower()
score = 0
# Pattern matching
for i, pattern in enumerate(patterns):
if re.search(pattern, col_lower):
score += 100 - (i * 10) # Higher score for earlier patterns
break
# Boost score if column name contains scenario-relevant terms
if scenario_terms:
for term in scenario_terms:
if term in col_lower:
score += 25
return score
def _detect_column_types(df: pd.DataFrame) -> Dict[str, str]:
"""Detect the likely type/purpose of each column."""
column_types = {}
for col in df.columns:
col_lower = col.lower()
# Detect numeric columns that could be converted
sample = df[col].dropna().head(50)
numeric_convertible = False
if len(sample) > 0:
try:
numeric_sample = pd.to_numeric(sample, errors='coerce')
if numeric_sample.notna().sum() > len(sample) * 0.7:
numeric_convertible = True
except:
pass
# Categorize columns
if numeric_convertible:
if any(term in col_lower for term in ['id', 'number', 'code', 'index']):
column_types[col] = 'identifier'
elif any(term in col_lower for term in ['time', 'date', 'duration', 'wait', 'delay']):
column_types[col] = 'time_metric'
elif any(term in col_lower for term in ['cost', 'price', 'budget', 'fee', 'expense']):
column_types[col] = 'cost_metric'
elif any(term in col_lower for term in ['count', 'number', 'quantity', 'volume']):
column_types[col] = 'count_metric'
elif any(term in col_lower for term in ['rate', 'ratio', 'percent', 'score']):
column_types[col] = 'performance_metric'
else:
column_types[col] = 'numeric_metric'
else:
# String/categorical columns
unique_ratio = df[col].nunique() / len(df)
if unique_ratio < 0.1:
column_types[col] = 'category'
elif unique_ratio < 0.5:
column_types[col] = 'grouping'
else:
column_types[col] = 'text'
return column_types
@dataclass
class MappingResult:
    """Outcome of mapping abstract concepts onto concrete table columns."""
    # concept -> (table_name, column_name) chosen with confidence
    resolved: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    # concept -> candidate (table, column) pairs that need clarification
    ambiguous: Dict[str, List[Tuple[str, str]]] = field(default_factory=dict)
    # concepts for which no column matched at all
    missing: List[str] = field(default_factory=list)
    discovered: Dict[str, str] = field(default_factory=dict)  # Discovered column types, keyed "table.column"
def _extract_explicit_mappings_from_scenario(scenario_text: str, available_columns: List[Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
    """Extract explicit column mappings from scenario text.

    Scans the prose for direct references to known column names and assigns
    them to abstract concepts via three successive heuristics (column
    descriptions, task phrasing, bare identifier mentions). Later heuristics
    only fill concepts still unmapped, so earlier ones take precedence.

    Args:
        scenario_text: Free-text scenario description (may be empty/None).
        available_columns: (table_name, column_name) pairs known to exist.

    Returns:
        Mapping of concept name -> (table_name, column_name).
    """
    explicit_mappings = {}
    if not scenario_text:
        return explicit_mappings
    scenario_lower = scenario_text.lower()
    # Case-insensitive lookup: lowered column name -> (table, original name).
    # NOTE(review): columns with the same name in different tables collide
    # here and the last table wins — confirm that is acceptable.
    column_lookup = {}
    for table_name, col_name in available_columns:
        column_lookup[col_name.lower()] = (table_name, col_name)
    # Pattern 1: direct column descriptions, e.g. "Surgery_Median column contains ...".
    column_desc_patterns = [
        r'(\w+)\s+column\s+(?:contains|reports|shows|includes|represents)',
        r'column\s+(\w+)\s+(?:contains|reports|shows|includes|represents)',
        r'(\w+)\s+(?:contains|reports|shows|includes|represents)'
    ]
    for pattern in column_desc_patterns:
        matches = re.findall(pattern, scenario_text, re.IGNORECASE)
        for match in matches:
            col_name = match.lower()
            if col_name in column_lookup:
                # Classify the concept from the prose surrounding the FIRST
                # occurrence of the column name (50 chars before, 100 after).
                # NOTE(review): if a column is mentioned several times, only
                # the first occurrence's context is considered.
                context = scenario_text[max(0, scenario_text.lower().find(col_name)-50):scenario_text.lower().find(col_name)+100].lower()
                if any(term in context for term in ['wait', 'time', 'delay', 'duration']):
                    # Split wait metrics by name: median vs 90th percentile vs generic.
                    if 'median' in col_name:
                        explicit_mappings['wait_median'] = column_lookup[col_name]
                    elif '90' in col_name or 'percentile' in col_name:
                        explicit_mappings['wait_p90'] = column_lookup[col_name]
                    else:
                        explicit_mappings['wait_time'] = column_lookup[col_name]
                elif any(term in context for term in ['facility', 'hospital', 'clinic', 'site']):
                    explicit_mappings['facility'] = column_lookup[col_name]
                elif any(term in context for term in ['specialty', 'service', 'department']):
                    explicit_mappings['specialty'] = column_lookup[col_name]
                elif any(term in context for term in ['zone', 'region', 'area', 'district']):
                    explicit_mappings['region'] = column_lookup[col_name]
    # Pattern 2: task phrasing naming a column, e.g. "calculate average for each facility".
    task_patterns = [
        (r'(?:for each|by)\s+(\w+)', ['facility', 'specialty', 'region']),
        (r'(?:identify|rank|list)\s+(\w+)', ['facility', 'specialty', 'region']),
        (r'average\s+(\w+)\s+(?:wait|time)', ['wait_median', 'wait_time']),
        (r'median\s+(\w+)', ['wait_median']),
        (r'90th\s+percentile\s+(\w+)', ['wait_p90'])
    ]
    for pattern, concepts in task_patterns:
        matches = re.findall(pattern, scenario_lower)
        for match in matches:
            match_lower = match.lower()
            if match_lower in column_lookup:
                # Assign the matched column to the FIRST still-unmapped
                # candidate concept, then stop for this match.
                for concept in concepts:
                    if concept not in explicit_mappings:
                        explicit_mappings[concept] = column_lookup[match_lower]
                        break
    # Pattern 3: bare identifiers in the prose that exactly match column names.
    explicit_columns = re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', scenario_text)
    for col_candidate in explicit_columns:
        col_lower = col_candidate.lower()
        if col_lower in column_lookup:
            # Each concept family is only filled while still unmapped by the
            # earlier, more explicit heuristics.
            if not any(concept in explicit_mappings for concept in ['facility', 'organization', 'department']):
                if re.search(r'facility|hospital|clinic|site|provider', col_lower):
                    explicit_mappings['facility'] = column_lookup[col_lower]
            if not any(concept in explicit_mappings for concept in ['specialty', 'service']):
                if re.search(r'specialty|service|department|type', col_lower):
                    explicit_mappings['specialty'] = column_lookup[col_lower]
            if not any(concept in explicit_mappings for concept in ['region', 'zone']):
                if re.search(r'zone|region|area|district', col_lower):
                    explicit_mappings['region'] = column_lookup[col_lower]
            if not any(concept in explicit_mappings for concept in ['wait_median', 'wait_time']):
                if re.search(r'.*median.*', col_lower) and re.search(r'wait|time|surgery|consult', col_lower):
                    explicit_mappings['wait_median'] = column_lookup[col_lower]
            if not any(concept in explicit_mappings for concept in ['wait_p90']):
                if re.search(r'.*(90|percentile).*', col_lower) and re.search(r'wait|time|surgery|consult', col_lower):
                    explicit_mappings['wait_p90'] = column_lookup[col_lower]
    return explicit_mappings
def _extract_explicit_tasks_from_scenario(scenario_text: str) -> List[str]:
"""Extract explicit task requirements from scenario text."""
tasks = []
if not scenario_text:
return tasks
scenario_lower = scenario_text.lower()
# Task extraction patterns
task_patterns = [
r'(?:your tasks?(?:\s+are)?[:\s]+)([^.]*?)(?:\.|$)',
r'(?:you (?:should|need to|are to|must)[:\s]+)([^.]*?)(?:\.|$)',
r'(?:tasks?[:\s]+)([^.]*?)(?:\.|deliverables|$)',
r'(?:\d+\.?\s*)([^.]*?)(?:\.|$)' # Numbered tasks
]
for pattern in task_patterns:
matches = re.findall(pattern, scenario_text, re.IGNORECASE | re.DOTALL)
for match in matches:
task = match.strip()
if len(task) > 10 and any(verb in task.lower() for verb in ['identify', 'calculate', 'analyze', 'compare', 'assess', 'determine', 'rank', 'list']):
tasks.append(task)
return tasks
def map_concepts(scenario_text: str, registry: DataRegistry) -> MappingResult:
    """Enhanced mapping that extracts explicit information from scenario text.

    Resolution proceeds in three steps: (1) explicit mappings mined from the
    scenario prose win outright, (2) remaining universal concepts are
    pattern-matched against column names, (3) near-tied candidates are
    reported as ambiguous unless the scenario text names one of them.

    Args:
        scenario_text: Free-text scenario description.
        registry: DataRegistry exposing names() and iter_tables(); each table
            is expected to provide .name and a pandas .df.

    Returns:
        MappingResult with resolved/ambiguous/missing concepts plus the
        per-column type tags recorded in .discovered.
    """
    result = MappingResult()
    # No data loaded: every universal concept is unresolvable.
    if not registry.names():
        result.missing = list(UNIVERSAL_CONCEPT_PATTERNS.keys())
        return result
    # Extract key terms from scenario to bias the scoring below.
    scenario_terms = _extract_key_terms_from_scenario(scenario_text)
    # Collect all available columns and record per-column type detection.
    all_columns = []
    for table in registry.iter_tables():
        column_types = _detect_column_types(table.df)
        result.discovered.update({f"{table.name}.{col}": col_type for col, col_type in column_types.items()})
        for col in table.df.columns:
            all_columns.append((table.name, str(col)))
    # STEP 1: Extract explicit mappings from scenario text.
    explicit_mappings = _extract_explicit_mappings_from_scenario(scenario_text, all_columns)
    # STEP 2: Explicit mappings are trusted as-is.
    for concept, (table_name, col_name) in explicit_mappings.items():
        result.resolved[concept] = (table_name, col_name)
    # STEP 3: Pattern-match the concepts not explicitly mapped.
    remaining_patterns = {k: v for k, v in UNIVERSAL_CONCEPT_PATTERNS.items() if k not in result.resolved}
    if remaining_patterns:
        # Augment with scenario-derived dynamic patterns.
        concept_patterns = _generate_dynamic_patterns(scenario_terms, remaining_patterns)
        for concept, patterns in concept_patterns.items():
            if concept in result.resolved:
                continue  # Skip already resolved
            # Score every column for this concept, best first.
            scores = [
                ((tbl, col), _score_column_match(col, patterns, scenario_terms))
                for (tbl, col) in all_columns
            ]
            scores.sort(key=lambda x: x[1], reverse=True)
            if not scores or scores[0][1] == 0:
                result.missing.append(concept)
                continue
            top_score = scores[0][1]
            # Candidates within 15 points of the top (floor of 70) are
            # considered near-ties and hence potentially ambiguous.
            threshold = max(70, top_score - 15)  # Higher threshold for explicit scenarios
            high_scoring = [pair for pair, score in scores if score >= threshold]
            # NOTE(review): if top_score < 70 (e.g. only the +25 scenario-term
            # boost fired), high_scoring is empty and the concept lands in
            # result.ambiguous with an EMPTY option list — confirm intended.
            if len(high_scoring) == 1:
                tbl, col = high_scoring[0]
                result.resolved[concept] = (tbl, col)
            else:
                # Prefer a high-scoring candidate explicitly named in the text.
                disambiguated = False
                for (tbl, col), score in scores[:3]:  # Check top 3
                    col_mentioned = col.lower() in scenario_text.lower()
                    if col_mentioned and score >= threshold:
                        result.resolved[concept] = (tbl, col)
                        disambiguated = True
                        break
                if not disambiguated:
                    result.ambiguous[concept] = high_scoring[:3]  # Limit to top 3
    return result
def build_phase1_questions(scenario_text: str, registry: DataRegistry, mapping: MappingResult, max_questions: int = 4) -> str:
    """Build minimal clarifying questions, only when truly necessary.

    Args:
        scenario_text: Free-text scenario description.
        registry: Data registry (currently unused; kept for interface stability).
        mapping: Concept-to-column mapping produced by map_concepts().
        max_questions: Upper bound on clarification questions in the
            comprehensive-scenario branch.

    Returns:
        A markdown string: either clarification questions or a "proceed"
        message when the mapping is already clear.
    """
    explicit_tasks = _extract_explicit_tasks_from_scenario(scenario_text)
    # Heuristic: a scenario with several explicit tasks plus data-description
    # vocabulary is treated as self-sufficient.
    has_detailed_tasks = len(explicit_tasks) >= 3
    has_data_descriptions = any(term in scenario_text.lower() for term in [
        'column', 'dataset', 'file', 'csv', 'records', 'contains', 'includes'
    ])
    if has_detailed_tasks and has_data_descriptions:
        # Comprehensive scenario: only ask about ambiguities the text itself
        # cannot settle.
        critical_questions = []
        for concept, options in mapping.ambiguous.items():
            if len(options) > 1:
                scenario_lower = scenario_text.lower()
                # If the scenario explicitly names one candidate column, take
                # it as the answer and skip the question. (Simplified from the
                # original, which redundantly re-counted occurrences right
                # after the same membership test.)
                clear_preference = None
                for table_name, col_name in options:
                    if col_name.lower() in scenario_lower:
                        clear_preference = f"{table_name}.{col_name}"
                        break
                if not clear_preference and len(critical_questions) < max_questions:
                    option_strs = [f"{tbl}.{col}" for tbl, col in options[:3]]
                    critical_questions.append(f"**Column Clarification**: For {concept.replace('_', ' ')}, use: {', '.join(option_strs)}?")
        if not critical_questions:
            return "**Proceeding with Analysis**: Scenario and data mappings are clear. Analyzing now..."
        return "**Quick Clarification**\n\n" + "\n".join(critical_questions)
    # Fallback: standard question generation for less comprehensive scenarios.
    questions = []
    # Ambiguous mappings for key entity concepts come first.
    important_concepts = ['facility', 'organization', 'department', 'specialty', 'region']
    for concept in important_concepts:
        if concept in mapping.ambiguous and len(questions) < max_questions:
            options = [f"{tbl}.{col}" for tbl, col in mapping.ambiguous[concept][:3]]
            questions.append(f"**Entity**: Which column represents {concept.replace('_', ' ')}? Options: {', '.join(options)}")
    # Missing critical data: grouping entity and primary metric.
    if len(questions) < max_questions:
        if not any(concept in mapping.resolved for concept in ['facility', 'organization', 'department']):
            questions.append("**Grouping**: What entities should be analyzed? (facilities, departments, regions, etc.)")
        if not any(concept in mapping.resolved for concept in ['wait_time', 'wait_median', 'score', 'performance']):
            questions.append("**Metric**: What is the primary metric to analyze? (wait times, scores, costs, etc.)")
    if not questions:
        return "**Analysis Ready**: Data structure understood. Proceeding with analysis..."
    return "**Clarification Questions**\n\n" + "\n".join(f"{i+1}. {q}" for i, q in enumerate(questions))