Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Any, Tuple, Optional, Set | |
| import pandas as pd | |
| from data_registry import DataRegistry | |
# Generic concept patterns that work across domains.
# Each concept maps to an ordered list of regexes; the column scorer gives
# earlier patterns more weight. Patterns are searched against lower-cased
# column names, so they must themselves be written in lower case (an
# upper-case literal such as "P50" could never match).
UNIVERSAL_CONCEPT_PATTERNS = {
    # Entity/grouping concepts
    "facility": [r"\bfacilit(y|ies)\b", r"\bhospital\b", r"\bsite\b", r"\bcentre\b", r"\bcenter\b", r"\blocation\b", r"\bprovider\b"],
    "organization": [r"\borganization\b", r"\bcompany\b", r"\bbusiness\b", r"\bfirm\b", r"\bentity\b"],
    "department": [r"\bdepartment\b", r"\bdivision\b", r"\bunit\b", r"\bsection\b"],
    "specialty": [r"\bspecialt(y|ies)\b", r"\bservice\b", r"\btype\b", r"\bcategory\b", r"\bkind\b"],
    "region": [r"\bzone\b", r"\bregion\b", r"\barea\b", r"\bdistrict\b", r"\bterritory\b"],
    # Time-based metrics
    "wait_time": [r"\bwait", r"\bdelay", r"\btime", r"\bduration", r"\blength"],
    "wait_median": [r"\bmedian\b.*\bwait", r"\bp50\b", r"\bwait.*\bmedian", r"median.*time"],
    "wait_p90": [r"\bp90\b", r"\b90(th)?\s*percentile\b", r"\bwait.*p90", r"90.*wait"],
    "response_time": [r"\bresponse\b.*\btime\b", r"\bprocessing\b.*\btime\b"],
    # Performance metrics
    "score": [r"\bscore\b", r"\brating\b", r"\bindex\b", r"\brank\b"],
    "efficiency": [r"\befficiency\b", r"\bthroughput\b", r"\bproductivity\b"],
    "quality": [r"\bquality\b", r"\bperformance\b", r"\boutcome\b"],
    "satisfaction": [r"\bsatisfaction\b", r"\bfeedback\b", r"\brating\b"],
    # Capacity metrics
    "capacity": [r"\bcapacity\b", r"\bvolume\b", r"\bsize\b", r"\blimit\b"],
    "utilization": [r"\butilization\b", r"\boccupancy\b", r"\busage\b"],
    "availability": [r"\bavailab\w+", r"\bopen\b", r"\bfree\b"],
    # Cost/financial metrics
    "cost": [r"\bcost\b", r"\bprice\b", r"\bexpense\b", r"\bfee\b", r"\bcharge\b"],
    "budget": [r"\bbudget\b", r"\bfunding\b", r"\ballocation\b"],
    "revenue": [r"\brevenue\b", r"\bincome\b", r"\bearnings\b"],
    # Count/volume metrics
    "count": [r"\bcount\b", r"\bnumber\b", r"\bquantity\b", r"\btotal\b"],
    "rate": [r"\brate\b", r"\bratio\b", r"\bpercent\b", r"\bfrequency\b"],
    "volume": [r"\bvolume\b", r"\bamount\b", r"\bquantity\b"],
}
def _extract_key_terms_from_scenario(scenario_text: str) -> Set[str]:
    """Return the distinct salient words of the scenario text.

    A word is a run of three or more ASCII letters bounded by regex word
    boundaries, taken from the lower-cased text; common stop words are
    discarded. An empty or missing scenario yields an empty set.
    """
    if not scenario_text:
        return set()

    ignored = frozenset((
        'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
        'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'do', 'does', 'did',
        'a', 'an', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they',
    ))
    candidates = set(re.findall(r'\b[a-zA-Z]{3,}\b', scenario_text.lower()))
    # set - frozenset keeps the left operand's type, so a plain set is returned.
    return candidates - ignored
def _generate_dynamic_patterns(scenario_terms: Set[str], existing_patterns: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Extend concept patterns with terms taken from the scenario text.

    Each term of at least 4 characters is considered, in sorted order:

    * a term containing 'hospital', 'clinic', 'school', 'department' or
      'facility' is appended as a whole-word regex to the ``facility``
      pattern list. The list is created if absent, and the regex is skipped
      if it is already present.
    * otherwise, a term containing 'time', 'score', 'rate', 'cost' or 'wait'
      becomes a new concept ``metric_<term>`` with that single regex.

    Neither the input mapping nor any of its lists is modified. The pattern
    lists are copied before anything is appended. A shallow ``dict.copy()``
    would share the lists, and map_concepts passes lists taken directly from
    UNIVERSAL_CONCEPT_PATTERNS. Scenario terms would then accumulate in that
    module-level table across calls.

    Args:
        scenario_terms: Terms extracted from the scenario; None is treated as
            empty.
        existing_patterns: Concept -> ordered list of regex patterns.

    Returns:
        A new concept -> pattern-list mapping.
    """
    # Deep-enough copy: new outer dict AND new lists.
    dynamic_patterns = {concept: list(patterns) for concept, patterns in existing_patterns.items()}

    # Sorted iteration makes the appended order deterministic. The scorer
    # weights patterns by position, so set order would otherwise leak into
    # the scores.
    for term in sorted(scenario_terms or ()):
        if len(term) < 4:  # only meaningful terms
            continue
        term_pattern = rf"\b{re.escape(term)}\b"
        if any(indicator in term for indicator in ['hospital', 'clinic', 'school', 'department', 'facility']):
            facility_patterns = dynamic_patterns.setdefault('facility', [])
            if term_pattern not in facility_patterns:
                facility_patterns.append(term_pattern)
        elif any(indicator in term for indicator in ['time', 'score', 'rate', 'cost', 'wait']):
            dynamic_patterns[f"metric_{term}"] = [term_pattern]
    return dynamic_patterns
def _score_column_match(col_name: str, patterns: List[str], scenario_terms: Optional[Set[str]] = None) -> int:
    """Score how well a column name matches a concept's regex patterns.

    The first pattern, in list order, that matches the column name adds a
    rank-weighted amount. That amount is 100 for the first pattern and 10
    less for each later one. It is floored at 10, so any match still scores
    above no match; without the floor, long pattern lists produced zero or
    negative contributions. Only one pattern is counted. Every scenario term
    that occurs as a substring of the lower-cased column name adds a further
    25.

    Before pattern matching, underscores in the name are replaced with
    spaces. Regex word boundaries treat '_' as a word character, so otherwise
    a whole-word pattern for "hospital" could never match a snake_case name
    such as "Hospital_Name".

    Args:
        col_name: Column name to score.
        patterns: Regex patterns in priority order, written in lower case.
        scenario_terms: Optional terms extracted from the scenario text.

    Returns:
        A non-negative score; 0 means nothing matched.
    """
    col_lower = col_name.lower()
    # Normalise snake_case separators so word-boundary patterns can fire.
    col_words = col_lower.replace("_", " ")

    score = 0
    for rank, pattern in enumerate(patterns):
        if re.search(pattern, col_words):
            score += max(100 - rank * 10, 10)
            break

    # Boost columns whose names echo words from the scenario.
    if scenario_terms:
        score += 25 * sum(1 for term in scenario_terms if term in col_lower)
    return score
def _detect_column_types(df: pd.DataFrame) -> Dict[str, str]:
    """Guess the role of each column from its label and a sample of values.

    A column counts as numeric when more than 70% of its first 50 non-null
    values convert with ``pd.to_numeric(errors='coerce')``. Numeric columns
    are classified by keywords in the lower-cased label. The rules are
    checked in order and the first hit wins: ``identifier``, ``time_metric``,
    ``cost_metric``, ``count_metric``, ``performance_metric``, otherwise
    ``numeric_metric``. Non-numeric columns are classified by their ratio of
    distinct values to rows: ``category`` (< 0.1), ``grouping`` (< 0.5),
    otherwise ``text``.

    Args:
        df: Table to inspect. Column labels need not be strings. An empty
            frame is allowed; its columns get a distinct ratio of 0.

    Returns:
        Mapping from each column label to its detected type name.
    """
    # (type name, label keywords) checked in order for numeric columns.
    # NOTE: 'number' appears for both identifier and count_metric; the
    # identifier rule is checked first, so it always wins for 'number'.
    numeric_rules = (
        ('identifier', ('id', 'number', 'code', 'index')),
        ('time_metric', ('time', 'date', 'duration', 'wait', 'delay')),
        ('cost_metric', ('cost', 'price', 'budget', 'fee', 'expense')),
        ('count_metric', ('count', 'number', 'quantity', 'volume')),
        ('performance_metric', ('rate', 'ratio', 'percent', 'score')),
    )

    column_types: Dict[str, str] = {}
    n_rows = len(df)
    for col in df.columns:
        # Labels may be non-strings (e.g. integer headers), so stringify first.
        col_lower = str(col).lower()
        series = df[col]

        sample = series.dropna().head(50)
        numeric_convertible = False
        if len(sample) > 0:
            try:
                numeric_sample = pd.to_numeric(sample, errors='coerce')
            except (TypeError, ValueError):
                # e.g. nested list/dict cells that cannot be coerced
                numeric_sample = None
            if numeric_sample is not None and numeric_sample.notna().sum() > len(sample) * 0.7:
                numeric_convertible = True

        if numeric_convertible:
            column_types[col] = next(
                (label for label, keywords in numeric_rules if any(k in col_lower for k in keywords)),
                'numeric_metric',
            )
            continue

        # String/categorical columns; guard the empty frame (n_rows == 0).
        unique_ratio = series.nunique() / n_rows if n_rows else 0.0
        if unique_ratio < 0.1:
            column_types[col] = 'category'
        elif unique_ratio < 0.5:
            column_types[col] = 'grouping'
        else:
            column_types[col] = 'text'
    return column_types
@dataclass
class MappingResult:
    """Outcome of mapping analysis concepts onto dataset columns.

    Declared as a dataclass so each ``field(default_factory=...)`` yields a
    fresh, independent container per instance. Without the decorator the
    attributes would be raw ``dataclasses.Field`` objects.
    """

    # concept -> (table name, column name) chosen for that concept
    resolved: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    # concept -> candidate (table name, column name) pairs needing clarification
    ambiguous: Dict[str, List[Tuple[str, str]]] = field(default_factory=dict)
    # concept names that could not be mapped to any column
    missing: List[str] = field(default_factory=list)
    # "table.column" -> detected column type
    discovered: Dict[str, str] = field(default_factory=dict)
def _concept_from_description_context(col_key: str, context: str) -> Optional[str]:
    """Return the concept suggested by text around a described column, or None.

    Wait/time words win first; the column's own name then refines the
    concept to median or p90. Facility, specialty and region words are
    checked next, in that order.
    """
    if any(term in context for term in ['wait', 'time', 'delay', 'duration']):
        if 'median' in col_key:
            return 'wait_median'
        if '90' in col_key or 'percentile' in col_key:
            return 'wait_p90'
        return 'wait_time'
    if any(term in context for term in ['facility', 'hospital', 'clinic', 'site']):
        return 'facility'
    if any(term in context for term in ['specialty', 'service', 'department']):
        return 'specialty'
    if any(term in context for term in ['zone', 'region', 'area', 'district']):
        return 'region'
    return None


def _extract_explicit_mappings_from_scenario(scenario_text: str, available_columns: List[Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
    """Read concept -> column mappings that the scenario text states outright.

    Three passes run in order, and each sees the mappings made by earlier
    passes:

    1. Column descriptions ("X column contains ...", "column X reports ...",
       "X shows ..."). When X names an available column, the text from 50
       characters before to 100 after *that* mention picks the concept.
       Later descriptions overwrite earlier ones for the same concept.
    2. Task phrases ("for each X", "by X", "rank X", "median X", ...). X is
       assigned to the first of the phrase's concepts that is not yet mapped.
    3. Bare mentions of column names. A mentioned column is mapped by
       keywords in its own name, but only for concept groups still unmapped.

    Args:
        scenario_text: Scenario description; empty/None yields ``{}``.
        available_columns: (table name, column name) pairs. Names are matched
            case-insensitively; if two tables share a column name, the later
            pair wins.

    Returns:
        Concept -> (table name, column name).
    """
    explicit_mappings: Dict[str, Tuple[str, str]] = {}
    if not scenario_text:
        return explicit_mappings

    scenario_lower = scenario_text.lower()
    column_lookup = {col_name.lower(): (table_name, col_name) for table_name, col_name in available_columns}

    # Pass 1: direct column descriptions like "Surgery_Median column contains...".
    column_desc_patterns = [
        r'(\w+)\s+column\s+(?:contains|reports|shows|includes|represents)',
        r'column\s+(\w+)\s+(?:contains|reports|shows|includes|represents)',
        r'(\w+)\s+(?:contains|reports|shows|includes|represents)'
    ]
    for pattern in column_desc_patterns:
        for match in re.finditer(pattern, scenario_text, re.IGNORECASE):
            col_key = match.group(1).lower()
            if col_key not in column_lookup:
                continue
            # Take the context around this very mention. Searching for the
            # first substring occurrence could land inside an unrelated word
            # (e.g. "website" for a "Site" column).
            start = match.start(1)
            context = scenario_text[max(0, start - 50):start + 100].lower()
            concept = _concept_from_description_context(col_key, context)
            if concept is not None:
                explicit_mappings[concept] = column_lookup[col_key]

    # Pass 2: task-based identification like "calculate average for each facility".
    task_patterns = [
        (r'(?:for each|by)\s+(\w+)', ['facility', 'specialty', 'region']),
        (r'(?:identify|rank|list)\s+(\w+)', ['facility', 'specialty', 'region']),
        (r'average\s+(\w+)\s+(?:wait|time)', ['wait_median', 'wait_time']),
        (r'median\s+(\w+)', ['wait_median']),
        (r'90th\s+percentile\s+(\w+)', ['wait_p90'])
    ]
    for pattern, concepts in task_patterns:
        for word in re.findall(pattern, scenario_lower):
            entry = column_lookup.get(word.lower())
            if entry is None:
                continue
            for concept in concepts:
                if concept not in explicit_mappings:
                    explicit_mappings[concept] = entry
                    break

    # Pass 3: column names mentioned verbatim, mapped by keywords in the name.
    for token in re.findall(r'\b([A-Za-z_][A-Za-z0-9_]*)\b', scenario_text):
        col_lower = token.lower()
        entry = column_lookup.get(col_lower)
        if entry is None:
            continue
        if not any(c in explicit_mappings for c in ['facility', 'organization', 'department']):
            if re.search(r'facility|hospital|clinic|site|provider', col_lower):
                explicit_mappings['facility'] = entry
        if not any(c in explicit_mappings for c in ['specialty', 'service']):
            if re.search(r'specialty|service|department|type', col_lower):
                explicit_mappings['specialty'] = entry
        if not any(c in explicit_mappings for c in ['region', 'zone']):
            if re.search(r'zone|region|area|district', col_lower):
                explicit_mappings['region'] = entry
        is_wait_like = re.search(r'wait|time|surgery|consult', col_lower) is not None
        if not any(c in explicit_mappings for c in ['wait_median', 'wait_time']):
            if 'median' in col_lower and is_wait_like:
                explicit_mappings['wait_median'] = entry
        if 'wait_p90' not in explicit_mappings:
            if re.search(r'90|percentile', col_lower) and is_wait_like:
                explicit_mappings['wait_p90'] = entry
    return explicit_mappings
def _extract_explicit_tasks_from_scenario(scenario_text: str) -> List[str]:
    """Extract explicit task requirements from scenario text.

    Four regexes each capture the phrase that follows one of these openers:
    "your task(s) [are]:", "you should / need to / are to / must",
    "task(s):", or a number such as "1.". A capture runs to the next period
    or the end of the text; for "task(s):" it may also end at the word
    "deliverables". A stripped capture is kept when it is longer than 10
    characters and contains one of the action verbs below (case-insensitive).

    The patterns overlap: "Your tasks: ..." is captured by both the first and
    the third. Each distinct phrase is therefore returned once, in the order
    first found. build_phase1_questions counts this list, so duplicates
    would overstate how many tasks the scenario lists.

    Args:
        scenario_text: Scenario description; may be empty or None.

    Returns:
        Distinct task phrases in discovery order; empty if none are found.
    """
    if not scenario_text:
        return []

    task_patterns = [
        r'(?:your tasks?(?:\s+are)?[:\s]+)([^.]*?)(?:\.|$)',
        r'(?:you (?:should|need to|are to|must)[:\s]+)([^.]*?)(?:\.|$)',
        r'(?:tasks?[:\s]+)([^.]*?)(?:\.|deliverables|$)',
        r'(?:\d+\.?\s*)([^.]*?)(?:\.|$)'  # Numbered tasks
    ]
    action_verbs = ('identify', 'calculate', 'analyze', 'compare', 'assess', 'determine', 'rank', 'list')

    tasks: List[str] = []
    seen: Set[str] = set()
    for pattern in task_patterns:
        for match in re.findall(pattern, scenario_text, re.IGNORECASE | re.DOTALL):
            task = match.strip()
            if task in seen:
                continue
            task_lower = task.lower()
            if len(task) > 10 and any(verb in task_lower for verb in action_verbs):
                seen.add(task)
                tasks.append(task)
    return tasks
def map_concepts(scenario_text: str, registry: DataRegistry) -> MappingResult:
    """Map analysis concepts onto columns of the registered tables.

    Procedure:

    1. Detect a type for every column and record it in ``result.discovered``
       under ``"<table>.<column>"``.
    2. Accept every mapping stated explicitly in the scenario text as
       resolved.
    3. For each concept still unresolved, score every column. The concepts
       are the universal patterns plus scenario-derived dynamic concepts.
       Columns scoring at least ``max(70, top_score - 15)`` are candidates:

       * exactly one candidate: the concept is resolved to it;
       * several: it is resolved to the first of the three best-scoring
         columns that is both a candidate and literally mentioned in the
         scenario; if none is mentioned, up to three candidates are recorded
         as ambiguous;
       * none (the top score is 0, or every score is below the threshold):
         the concept is recorded as missing.

    Args:
        scenario_text: Scenario description; may be empty or None.
        registry: Source of tables. Uses ``names()`` and ``iter_tables()``;
            each table supplies ``name`` and a DataFrame ``df``.

    Returns:
        A MappingResult. If the registry has no tables, every universal
        concept is reported as missing.
    """
    result = MappingResult()
    if not registry.names():
        result.missing = list(UNIVERSAL_CONCEPT_PATTERNS.keys())
        return result

    # Hoisted once; also tolerates a None scenario.
    scenario_lower = (scenario_text or "").lower()
    scenario_terms = _extract_key_terms_from_scenario(scenario_text)

    # Collect every (table, column) pair and record detected column types.
    all_columns: List[Tuple[str, str]] = []
    for table in registry.iter_tables():
        column_types = _detect_column_types(table.df)
        result.discovered.update({f"{table.name}.{col}": col_type for col, col_type in column_types.items()})
        all_columns.extend((table.name, str(col)) for col in table.df.columns)

    # Explicit mappings from the scenario text take priority.
    result.resolved.update(_extract_explicit_mappings_from_scenario(scenario_text, all_columns))

    remaining_patterns = {k: v for k, v in UNIVERSAL_CONCEPT_PATTERNS.items() if k not in result.resolved}
    if not remaining_patterns:
        return result

    concept_patterns = _generate_dynamic_patterns(scenario_terms, remaining_patterns)
    for concept, patterns in concept_patterns.items():
        if concept in result.resolved:
            continue

        # Stable sort: equal scores keep table/column order.
        scores = sorted(
            (((tbl, col), _score_column_match(col, patterns, scenario_terms)) for tbl, col in all_columns),
            key=lambda item: item[1],
            reverse=True,
        )
        if not scores or scores[0][1] <= 0:
            result.missing.append(concept)
            continue

        top_score = scores[0][1]
        threshold = max(70, top_score - 15)
        high_scoring = [pair for pair, score in scores if score >= threshold]

        if not high_scoring:
            # Only weak matches (e.g. a scenario-term boost alone). Report the
            # concept as missing instead of ambiguous with no options.
            result.missing.append(concept)
            continue
        if len(high_scoring) == 1:
            result.resolved[concept] = high_scoring[0]
            continue

        # Several candidates: a literal mention in the scenario breaks the tie.
        for (tbl, col), score in scores[:3]:
            if score >= threshold and col.lower() in scenario_lower:
                result.resolved[concept] = (tbl, col)
                break
        else:
            result.ambiguous[concept] = high_scoring[:3]
    return result
def build_phase1_questions(scenario_text: str, registry: DataRegistry, mapping: MappingResult, max_questions: int = 4) -> str:
    """Build minimal clarifying questions, or a go-ahead message.

    A scenario counts as comprehensive when it has at least three explicit
    tasks and mentions data ('column', 'dataset', 'file', 'csv', 'records',
    'contains' or 'includes'). In that case a question is produced only for
    an ambiguous concept with more than one candidate column, none of which
    is named in the scenario.

    Otherwise questions are asked about ambiguous entity concepts that have
    candidates. A grouping question follows if no facility, organization or
    department is resolved, and a metric question if no primary metric is
    resolved.

    In both cases at most ``max_questions`` questions are returned.

    Args:
        scenario_text: Scenario description; may be empty or None.
        registry: Not used by this function; kept for interface compatibility.
        mapping: Result of the concept mapping step.
        max_questions: Upper bound on the number of questions.

    Returns:
        Markdown text: either the questions or a short message that the
        analysis can proceed.
    """
    scenario_lower = (scenario_text or "").lower()

    explicit_tasks = _extract_explicit_tasks_from_scenario(scenario_text)
    has_detailed_tasks = len(explicit_tasks) >= 3
    has_data_descriptions = any(term in scenario_lower for term in [
        'column', 'dataset', 'file', 'csv', 'records', 'contains', 'includes'
    ])

    if has_detailed_tasks and has_data_descriptions:
        critical_questions = []
        for concept, options in mapping.ambiguous.items():
            if len(critical_questions) >= max_questions:
                break
            if len(options) <= 1:
                continue
            # A column literally named in the scenario is treated as the author's choice.
            if any(col_name.lower() in scenario_lower for _, col_name in options):
                continue
            option_strs = [f"{tbl}.{col}" for tbl, col in options[:3]]
            critical_questions.append(f"**Column Clarification**: For {concept.replace('_', ' ')}, use: {', '.join(option_strs)}?")
        if not critical_questions:
            return "**Proceeding with Analysis**: Scenario and data mappings are clear. Analyzing now..."
        return "**Quick Clarification**\n\n" + "\n".join(critical_questions)

    # Less comprehensive scenarios: standard question generation.
    questions: List[str] = []
    for concept in ['facility', 'organization', 'department', 'specialty', 'region']:
        if len(questions) >= max_questions:
            break
        candidates = mapping.ambiguous.get(concept)
        if not candidates:  # nothing useful to offer as options
            continue
        options = [f"{tbl}.{col}" for tbl, col in candidates[:3]]
        questions.append(f"**Entity**: Which column represents {concept.replace('_', ' ')}? Options: {', '.join(options)}")

    # Each follow-up is capped on its own so max_questions is never exceeded.
    if len(questions) < max_questions and not any(
        concept in mapping.resolved for concept in ['facility', 'organization', 'department']
    ):
        questions.append("**Grouping**: What entities should be analyzed? (facilities, departments, regions, etc.)")
    # NOTE(review): 'performance' is not a key of UNIVERSAL_CONCEPT_PATTERNS;
    # confirm whether other metric concepts (cost, wait_p90, ...) should count here.
    if len(questions) < max_questions and not any(
        concept in mapping.resolved for concept in ['wait_time', 'wait_median', 'score', 'performance']
    ):
        questions.append("**Metric**: What is the primary metric to analyze? (wait times, scores, costs, etc.)")

    if not questions:
        return "**Analysis Ready**: Data structure understood. Proceeding with analysis..."
    return "**Clarification Questions**\n\n" + "\n".join(f"{i+1}. {q}" for i, q in enumerate(questions))