Update sozo_gen.py
Browse files- sozo_gen.py +271 -28
sozo_gen.py
CHANGED
|
@@ -310,11 +310,11 @@ def perform_autonomous_data_analysis(df: pd.DataFrame, user_ctx: str, filename:
|
|
| 310 |
"""
|
| 311 |
logging.info("Performing autonomous data analysis...")
|
| 312 |
|
| 313 |
-
# Basic data profiling
|
| 314 |
basic_info = {
|
| 315 |
"shape": df.shape,
|
| 316 |
"columns": list(df.columns),
|
| 317 |
-
"dtypes": df.dtypes.
|
| 318 |
"filename": filename,
|
| 319 |
"user_context": user_ctx
|
| 320 |
}
|
|
@@ -388,11 +388,11 @@ def classify_dataset_domain(df: pd.DataFrame, filename: str) -> Dict[str, Any]:
|
|
| 388 |
|
| 389 |
return {
|
| 390 |
"primary_domain": primary_domain,
|
| 391 |
-
"domain_confidence": domain_scores.get(primary_domain, 0),
|
| 392 |
-
"domain_scores": domain_scores,
|
| 393 |
"data_characteristics": {
|
| 394 |
-
"numeric_ratio": numeric_ratio,
|
| 395 |
-
"categorical_ratio": categorical_ratio,
|
| 396 |
"is_time_series": detect_time_series(df),
|
| 397 |
"is_transactional": detect_transactional_data(df),
|
| 398 |
"is_experimental": detect_experimental_data(df)
|
|
@@ -412,12 +412,20 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 412 |
"missing_data": {}
|
| 413 |
}
|
| 414 |
|
| 415 |
-
# Summary statistics for numeric columns
|
| 416 |
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 417 |
if len(numeric_cols) > 0:
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
-
# Correlation analysis
|
| 421 |
if len(numeric_cols) > 1:
|
| 422 |
corr_matrix = df[numeric_cols].corr()
|
| 423 |
# Find strong correlations
|
|
@@ -425,11 +433,11 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 425 |
for i in range(len(corr_matrix.columns)):
|
| 426 |
for j in range(i+1, len(corr_matrix.columns)):
|
| 427 |
corr_val = corr_matrix.iloc[i, j]
|
| 428 |
-
if abs(corr_val) > 0.7: # Strong correlation threshold
|
| 429 |
strong_corrs.append({
|
| 430 |
"var1": corr_matrix.columns[i],
|
| 431 |
"var2": corr_matrix.columns[j],
|
| 432 |
-
"correlation": corr_val
|
| 433 |
})
|
| 434 |
profile["correlations"] = {"strong_correlations": strong_corrs}
|
| 435 |
|
|
@@ -438,16 +446,22 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 438 |
if len(categorical_cols) > 0:
|
| 439 |
profile["categorical_analysis"] = {}
|
| 440 |
for col in categorical_cols:
|
|
|
|
| 441 |
profile["categorical_analysis"][col] = {
|
| 442 |
-
"unique_count": df[col].nunique(),
|
| 443 |
-
"top_values":
|
| 444 |
}
|
| 445 |
|
| 446 |
-
# Missing data analysis
|
| 447 |
missing_data = df.isnull().sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
profile["missing_data"] = {
|
| 449 |
-
"columns_with_missing":
|
| 450 |
-
"total_missing_percentage": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
|
| 451 |
}
|
| 452 |
|
| 453 |
return profile
|
|
@@ -472,11 +486,11 @@ def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 472 |
for col2 in numeric_cols:
|
| 473 |
if col1 != col2:
|
| 474 |
correlation = df[col1].corr(df[col2])
|
| 475 |
-
if abs(correlation) > 0.5: # Moderate to strong correlation
|
| 476 |
relationships["key_relationships"].append({
|
| 477 |
"variable1": col1,
|
| 478 |
"variable2": col2,
|
| 479 |
-
"relationship_strength": correlation,
|
| 480 |
"relationship_type": "positive" if correlation > 0 else "negative"
|
| 481 |
})
|
| 482 |
|
|
@@ -489,8 +503,8 @@ def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 489 |
relationships["patterns"].append({
|
| 490 |
"column": col,
|
| 491 |
"pattern_type": "categorical_distribution",
|
| 492 |
-
"dominant_category": value_counts.index[0],
|
| 493 |
-
"dominance_percentage": (value_counts.iloc[0] / len(df)) * 100
|
| 494 |
})
|
| 495 |
|
| 496 |
return relationships
|
|
@@ -526,7 +540,7 @@ def analyze_temporal_patterns(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 526 |
"start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
|
| 527 |
"end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
|
| 528 |
},
|
| 529 |
-
"time_span_days": (df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days,
|
| 530 |
"frequency": detect_temporal_frequency(df_temp[primary_date_col])
|
| 531 |
}
|
| 532 |
|
|
@@ -544,16 +558,16 @@ def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 544 |
"data_consistency": {}
|
| 545 |
}
|
| 546 |
|
| 547 |
-
# Completeness assessment
|
| 548 |
-
completeness = (1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
|
| 549 |
quality_metrics["data_completeness"] = completeness
|
| 550 |
|
| 551 |
# Identify quality issues
|
| 552 |
if completeness < 95:
|
| 553 |
quality_metrics["quality_issues"].append("Missing data detected")
|
| 554 |
|
| 555 |
-
# Check for duplicates
|
| 556 |
-
duplicate_rows = df.duplicated().sum()
|
| 557 |
if duplicate_rows > 0:
|
| 558 |
quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
|
| 559 |
|
|
@@ -563,11 +577,11 @@ def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
|
|
| 563 |
if df[col].str.isnumeric().any() and not df[col].str.isnumeric().all():
|
| 564 |
quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
|
| 565 |
|
| 566 |
-
# Calculate overall quality score
|
| 567 |
-
base_score = 100
|
| 568 |
base_score -= (100 - completeness) * 0.5 # Penalize missing data
|
| 569 |
base_score -= len(quality_metrics["quality_issues"]) * 5 # Penalize each quality issue
|
| 570 |
-
quality_metrics["overall_quality_score"] = max(0, base_score)
|
| 571 |
|
| 572 |
return quality_metrics
|
| 573 |
|
|
@@ -839,6 +853,235 @@ def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_
|
|
| 839 |
return {"raw_md": md, "chartUrls": chart_urls}
|
| 840 |
|
| 841 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
|
| 843 |
"""
|
| 844 |
Generates a basic fallback report when enhanced generation fails.
|
|
|
|
| 310 |
"""
|
| 311 |
logging.info("Performing autonomous data analysis...")
|
| 312 |
|
| 313 |
+
# Basic data profiling with JSON-safe types
|
| 314 |
basic_info = {
|
| 315 |
"shape": df.shape,
|
| 316 |
"columns": list(df.columns),
|
| 317 |
+
"dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
|
| 318 |
"filename": filename,
|
| 319 |
"user_context": user_ctx
|
| 320 |
}
|
|
|
|
| 388 |
|
| 389 |
return {
|
| 390 |
"primary_domain": primary_domain,
|
| 391 |
+
"domain_confidence": int(domain_scores.get(primary_domain, 0)),
|
| 392 |
+
"domain_scores": {k: int(v) for k, v in domain_scores.items()},
|
| 393 |
"data_characteristics": {
|
| 394 |
+
"numeric_ratio": float(numeric_ratio),
|
| 395 |
+
"categorical_ratio": float(categorical_ratio),
|
| 396 |
"is_time_series": detect_time_series(df),
|
| 397 |
"is_transactional": detect_transactional_data(df),
|
| 398 |
"is_experimental": detect_experimental_data(df)
|
|
|
|
| 412 |
"missing_data": {}
|
| 413 |
}
|
| 414 |
|
| 415 |
+
# Summary statistics for numeric columns with JSON-safe conversion
|
| 416 |
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 417 |
if len(numeric_cols) > 0:
|
| 418 |
+
desc_stats = df[numeric_cols].describe()
|
| 419 |
+
# Convert to JSON-safe format
|
| 420 |
+
profile["summary_stats"] = {
|
| 421 |
+
col: {
|
| 422 |
+
stat: float(val) if pd.notna(val) else None
|
| 423 |
+
for stat, val in desc_stats[col].items()
|
| 424 |
+
}
|
| 425 |
+
for col in desc_stats.columns
|
| 426 |
+
}
|
| 427 |
|
| 428 |
+
# Correlation analysis with JSON-safe conversion
|
| 429 |
if len(numeric_cols) > 1:
|
| 430 |
corr_matrix = df[numeric_cols].corr()
|
| 431 |
# Find strong correlations
|
|
|
|
| 433 |
for i in range(len(corr_matrix.columns)):
|
| 434 |
for j in range(i+1, len(corr_matrix.columns)):
|
| 435 |
corr_val = corr_matrix.iloc[i, j]
|
| 436 |
+
if abs(corr_val) > 0.7 and pd.notna(corr_val): # Strong correlation threshold
|
| 437 |
strong_corrs.append({
|
| 438 |
"var1": corr_matrix.columns[i],
|
| 439 |
"var2": corr_matrix.columns[j],
|
| 440 |
+
"correlation": float(corr_val)
|
| 441 |
})
|
| 442 |
profile["correlations"] = {"strong_correlations": strong_corrs}
|
| 443 |
|
|
|
|
| 446 |
if len(categorical_cols) > 0:
|
| 447 |
profile["categorical_analysis"] = {}
|
| 448 |
for col in categorical_cols:
|
| 449 |
+
value_counts = df[col].value_counts().head(5)
|
| 450 |
profile["categorical_analysis"][col] = {
|
| 451 |
+
"unique_count": int(df[col].nunique()),
|
| 452 |
+
"top_values": {str(k): int(v) for k, v in value_counts.items()}
|
| 453 |
}
|
| 454 |
|
| 455 |
+
# Missing data analysis with JSON-safe conversion
|
| 456 |
missing_data = df.isnull().sum()
|
| 457 |
+
missing_dict = {}
|
| 458 |
+
for col, missing_count in missing_data.items():
|
| 459 |
+
if missing_count > 0:
|
| 460 |
+
missing_dict[col] = int(missing_count)
|
| 461 |
+
|
| 462 |
profile["missing_data"] = {
|
| 463 |
+
"columns_with_missing": missing_dict,
|
| 464 |
+
"total_missing_percentage": float((df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100)
|
| 465 |
}
|
| 466 |
|
| 467 |
return profile
|
|
|
|
| 486 |
for col2 in numeric_cols:
|
| 487 |
if col1 != col2:
|
| 488 |
correlation = df[col1].corr(df[col2])
|
| 489 |
+
if abs(correlation) > 0.5 and pd.notna(correlation): # Moderate to strong correlation
|
| 490 |
relationships["key_relationships"].append({
|
| 491 |
"variable1": col1,
|
| 492 |
"variable2": col2,
|
| 493 |
+
"relationship_strength": float(correlation),
|
| 494 |
"relationship_type": "positive" if correlation > 0 else "negative"
|
| 495 |
})
|
| 496 |
|
|
|
|
| 503 |
relationships["patterns"].append({
|
| 504 |
"column": col,
|
| 505 |
"pattern_type": "categorical_distribution",
|
| 506 |
+
"dominant_category": str(value_counts.index[0]),
|
| 507 |
+
"dominance_percentage": float((value_counts.iloc[0] / len(df)) * 100)
|
| 508 |
})
|
| 509 |
|
| 510 |
return relationships
|
|
|
|
| 540 |
"start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
|
| 541 |
"end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
|
| 542 |
},
|
| 543 |
+
"time_span_days": int((df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days),
|
| 544 |
"frequency": detect_temporal_frequency(df_temp[primary_date_col])
|
| 545 |
}
|
| 546 |
|
|
|
|
| 558 |
"data_consistency": {}
|
| 559 |
}
|
| 560 |
|
| 561 |
+
# Completeness assessment with JSON-safe conversion
|
| 562 |
+
completeness = float((1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100)
|
| 563 |
quality_metrics["data_completeness"] = completeness
|
| 564 |
|
| 565 |
# Identify quality issues
|
| 566 |
if completeness < 95:
|
| 567 |
quality_metrics["quality_issues"].append("Missing data detected")
|
| 568 |
|
| 569 |
+
# Check for duplicates with JSON-safe conversion
|
| 570 |
+
duplicate_rows = int(df.duplicated().sum())
|
| 571 |
if duplicate_rows > 0:
|
| 572 |
quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
|
| 573 |
|
|
|
|
| 577 |
if df[col].str.isnumeric().any() and not df[col].str.isnumeric().all():
|
| 578 |
quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
|
| 579 |
|
| 580 |
+
# Calculate overall quality score with JSON-safe conversion
|
| 581 |
+
base_score = 100.0
|
| 582 |
base_score -= (100 - completeness) * 0.5 # Penalize missing data
|
| 583 |
base_score -= len(quality_metrics["quality_issues"]) * 5 # Penalize each quality issue
|
| 584 |
+
quality_metrics["overall_quality_score"] = float(max(0, base_score))
|
| 585 |
|
| 586 |
return quality_metrics
|
| 587 |
|
|
|
|
| 853 |
return {"raw_md": md, "chartUrls": chart_urls}
|
| 854 |
|
| 855 |
|
| 856 |
+
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
    """
    Generates a basic fallback report when enhanced generation fails.

    Args:
        autonomous_context: Analysis context. The row and column counts are
            read from ``autonomous_context["basic_info"]["shape"]`` and the
            domain label from ``autonomous_context["domain"]["primary_domain"]``.

    Returns:
        A Markdown report containing a single ``<generate_chart: ...>`` tag.

    This is the last-resort path, so missing context entries are replaced with
    neutral wording instead of raising ``KeyError``/``IndexError``.
    """
    basic_info = autonomous_context.get("basic_info") or {}
    domain = (autonomous_context.get("domain") or {}).get("primary_domain") or "general"

    # ``shape`` is expected to be (rows, columns); tolerate it being absent or short.
    shape = basic_info.get("shape") or ()
    unknown_count = "an unknown number of"
    row_count = shape[0] if len(shape) > 0 else unknown_count
    column_count = shape[1] if len(shape) > 1 else unknown_count

    # The template is kept flush-left so the Markdown headings render as headings.
    return f"""
# What This Data Reveals

Looking at this {domain} dataset with {row_count} records, there are several key insights worth highlighting.

## The Numbers Tell a Story

This dataset contains {column_count} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.

<generate_chart: "bar | Data overview showing key metrics">

## What You Should Know

The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicate multiple analytical opportunities.

## Next Steps

Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.

*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
"""
|
| 884 |
+
# Removed - no longer needed since we're letting AI decide everything organically
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
    """
    Generates charts autonomously based on the report content and data characteristics.

    Chart descriptions are taken from the chart tags found in ``report_md``
    (capped at ``MAX_CHARTS``); when the report has none, suggestions derived
    from ``df`` are used instead. Each chart is rendered to a temporary PNG,
    uploaded to ``bucket`` and the temporary file is removed afterwards.

    Returns:
        Mapping of sanitized chart description -> ``blob.public_url`` for every
        chart that was rendered and uploaded. A chart that fails is logged and
        skipped.
    """
    requested = extract_chart_tags(report_md)[:MAX_CHARTS]
    uploaded = {}

    # Fall back to data-driven suggestions when the report asked for no charts.
    if not requested:
        requested = generate_intelligent_chart_suggestions(df, llm)

    spec_builder = ChartGenerator(llm, df)

    for desc in requested:
        try:
            # Key that is safe to store in Firebase.
            safe_desc = sanitize_for_firebase_key(desc)
            replacement_tag = f'<generate_chart: "{safe_desc}">'

            # NOTE(review): these rewrites only touch the local ``report_md``;
            # the function returns ``uploaded`` alone, so the caller never sees
            # the rewritten markdown - confirm whether that is intended.
            for original_tag in (f'<generate_chart: "{desc}">', f'<generate_chart: {desc}>'):
                report_md = report_md.replace(original_tag, replacement_tag)

            # Reserve a temporary PNG path; the handle is closed before rendering.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                img_path = Path(temp_file.name)
            try:
                spec = spec_builder.generate_chart_spec(desc)
                if execute_chart_spec(spec, df, img_path):
                    target = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                    blob = bucket.blob(target)
                    blob.upload_from_filename(str(img_path))

                    uploaded[safe_desc] = blob.public_url
                    logging.info(f"Generated autonomous chart: {safe_desc}")
            finally:
                # Always remove the temporary image, rendered or not.
                if img_path.exists():
                    img_path.unlink()

        except Exception as e:
            logging.error(f"Failed to generate chart '{desc}': {str(e)}")
            continue

    return uploaded
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
    """
    Generates intelligent chart suggestions based on data characteristics.

    Suggestions are ``"type | title | purpose"`` strings, added in this order
    when applicable: a time-series line chart, a histogram of the first numeric
    column, a scatter plot when more than one numeric column exists, and a bar
    chart of the first ``object`` column. ``llm`` is accepted but not used here.

    Returns:
        At most ``MAX_CHARTS`` suggestion strings.
    """
    numeric_names = df.select_dtypes(include=[np.number]).columns.tolist()
    text_names = df.select_dtypes(include=['object']).columns.tolist()

    ideas = []

    # Temporal trend, when the frame looks like a time series.
    if detect_time_series(df):
        ideas.append("line | Time series trend analysis | Show temporal patterns")

    # Distribution of the first numeric column.
    if numeric_names:
        ideas.append(f"hist | Distribution of {numeric_names[0]} | Understand data distribution")

    # Relationships need at least two numeric columns.
    if len(numeric_names) > 1:
        ideas.append("scatter | Correlation analysis | Identify relationships between variables")

    # Breakdown of the first categorical column.
    if text_names:
        ideas.append(f"bar | {text_names[0]} breakdown | Show categorical distribution")

    return ideas[:MAX_CHARTS]
|
| 961 |
+
|
| 962 |
+
|
| 963 |
+
# Helper functions (preserve existing functionality)
|
| 964 |
+
def detect_time_series(df: pd.DataFrame) -> bool:
    """Detect if dataset contains time series data.

    A column counts as temporal when any of the following holds:

    * its label contains ``date`` or ``time`` (case-insensitive),
    * it already has a datetime dtype,
    * it is non-numeric, has at least one non-null value, and every non-null
      value parses with ``pd.to_datetime``.

    Numeric columns are never parsed: ``pd.to_datetime`` accepts plain numbers
    as epoch offsets, which would make almost any dataset look temporal.

    Returns:
        True if at least one column is temporal, otherwise False.
    """
    for name, column in df.items():
        # Labels may be non-strings (e.g. integer positions for header-less data).
        label = str(name).lower()
        if 'date' in label or 'time' in label:
            return True

        if pd.api.types.is_datetime64_any_dtype(column):
            return True

        # Numbers (and booleans) always "parse" as epoch offsets - skip them.
        if pd.api.types.is_numeric_dtype(column):
            continue

        values = column.dropna()
        if values.empty:
            # An all-null column carries no evidence of being a date column.
            continue

        try:
            pd.to_datetime(values)
        except (ValueError, TypeError, OverflowError):
            continue
        return True

    return False
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
def detect_transactional_data(df: pd.DataFrame) -> bool:
    """Detect if dataset contains transactional data.

    Returns True when any column label contains one of the transaction
    keywords as a case-insensitive substring. Labels are converted with
    ``str()`` first so non-string labels (e.g. integer positions) do not raise.
    """
    transaction_indicators = ['transaction', 'payment', 'order', 'invoice', 'amount', 'quantity']
    columns_lower = [str(col).lower() for col in df.columns]
    return any(indicator in col for col in columns_lower for indicator in transaction_indicators)
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
def detect_experimental_data(df: pd.DataFrame) -> bool:
    """Detect if dataset contains experimental data.

    Returns True when any column label contains one of the experiment
    keywords as a case-insensitive substring. Labels are converted with
    ``str()`` first so non-string labels (e.g. integer positions) do not raise.
    """
    experimental_indicators = ['test', 'experiment', 'trial', 'group', 'treatment', 'control']
    columns_lower = [str(col).lower() for col in df.columns]
    return any(indicator in col for col in columns_lower for indicator in experimental_indicators)
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
def detect_temporal_frequency(date_series: pd.Series) -> str:
    """Detect the frequency of temporal data.

    The label is chosen from the median gap between consecutive sorted
    timestamps:

    * up to 1 day   -> ``"daily"``
    * up to 7 days  -> ``"weekly"``
    * up to 31 days -> ``"monthly"``
    * anything else -> ``"irregular"``

    Missing timestamps are ignored. With fewer than two real timestamps there
    is no gap to measure and ``"insufficient_data"`` is returned (counting
    missing values here would produce a NaT median and a misleading
    ``"irregular"``).
    """
    timestamps = date_series.dropna()
    if len(timestamps) < 2:
        return "insufficient_data"

    # Gaps between consecutive timestamps; the first diff is always missing.
    time_diffs = timestamps.sort_values().diff().dropna()
    median_diff = time_diffs.median()

    if median_diff <= pd.Timedelta(days=1):
        return "daily"
    elif median_diff <= pd.Timedelta(days=7):
        return "weekly"
    elif median_diff <= pd.Timedelta(days=31):
        return "monthly"
    else:
        return "irregular"
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
    """Determine the complexity level of analysis required.

    One point is scored for each condition that holds: more than 10000 rows,
    more than 20 columns, more than 5 numeric columns, more than 5 ``object``
    columns, and a primary domain of "scientific" or "financial".

    Returns:
        "high" for 3 or more points, "medium" for exactly 2, otherwise "low".
    """
    numeric_column_count = len(df.select_dtypes(include=[np.number]).columns)
    object_column_count = len(df.select_dtypes(include=['object']).columns)

    checks = (
        len(df) > 10000,                  # data size: rows
        len(df.columns) > 20,             # data size: columns
        numeric_column_count > 5,         # data type diversity
        object_column_count > 5,
        domain_analysis["primary_domain"] in ["scientific", "financial"],  # domain complexity
    )
    points = sum(1 for passed in checks if passed)

    if points >= 3:
        return "high"
    if points >= 2:
        return "medium"
    return "low"
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
    """
    Fallback to original report generation logic if enhanced version fails.

    Builds a prompt from the enhanced data context, asks ``llm`` for a Markdown
    report, then renders and uploads a chart for each chart tag found in it
    (capped at ``MAX_CHARTS``).

    Returns:
        ``{"raw_md": <markdown with sanitized chart tags>, "chartUrls": <dict>}``
        where ``chartUrls`` maps each sanitized chart description to the
        uploaded blob's public URL.

    A chart that fails to render or upload is logged and skipped, consistent
    with ``generate_autonomous_charts``, so one bad chart does not discard the
    whole report.
    """
    logging.info("Using fallback report generation")

    # Original logic preserved
    ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
    enhanced_ctx = enhance_data_context(df, ctx_dict)

    report_prompt = f"""
    You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
    **Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
    **Instructions:**
    1. **Executive Summary**: Start with a high-level summary of key findings.
    2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
    3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
    Valid chart types: bar, pie, line, scatter, hist.
    Generate insights that would be valuable to C-level executives.
    """

    md = llm.invoke(report_prompt).content
    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)

    for desc in chart_descs:
        # Rewrite both quoted and unquoted tag forms to the Firebase-safe key.
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')

        try:
            # Reserve a temporary PNG path; the handle is closed before rendering.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                img_path = Path(temp_file.name)
            try:
                chart_spec = chart_generator.generate_chart_spec(desc)
                if execute_chart_spec(chart_spec, df, img_path):
                    blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(img_path))
                    chart_urls[safe_desc] = blob.public_url
            finally:
                # Always remove the temporary image, rendered or not.
                if os.path.exists(img_path):
                    os.unlink(img_path)
        except Exception as e:
            # Best-effort per chart: keep the report and the other charts.
            logging.error("Failed to generate chart '%s': %s", desc, e)
            continue

    return {"raw_md": md, "chartUrls": chart_urls}
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
|
| 1086 |
"""
|
| 1087 |
Generates a basic fallback report when enhanced generation fails.
|