rairo committed on
Commit
2107a21
·
verified ·
1 Parent(s): 3d8622b

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +271 -28
sozo_gen.py CHANGED
@@ -310,11 +310,11 @@ def perform_autonomous_data_analysis(df: pd.DataFrame, user_ctx: str, filename:
310
  """
311
  logging.info("Performing autonomous data analysis...")
312
 
313
- # Basic data profiling
314
  basic_info = {
315
  "shape": df.shape,
316
  "columns": list(df.columns),
317
- "dtypes": df.dtypes.to_dict(),
318
  "filename": filename,
319
  "user_context": user_ctx
320
  }
@@ -388,11 +388,11 @@ def classify_dataset_domain(df: pd.DataFrame, filename: str) -> Dict[str, Any]:
388
 
389
  return {
390
  "primary_domain": primary_domain,
391
- "domain_confidence": domain_scores.get(primary_domain, 0),
392
- "domain_scores": domain_scores,
393
  "data_characteristics": {
394
- "numeric_ratio": numeric_ratio,
395
- "categorical_ratio": categorical_ratio,
396
  "is_time_series": detect_time_series(df),
397
  "is_transactional": detect_transactional_data(df),
398
  "is_experimental": detect_experimental_data(df)
@@ -412,12 +412,20 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
412
  "missing_data": {}
413
  }
414
 
415
- # Summary statistics for numeric columns
416
  numeric_cols = df.select_dtypes(include=[np.number]).columns
417
  if len(numeric_cols) > 0:
418
- profile["summary_stats"] = df[numeric_cols].describe().to_dict()
 
 
 
 
 
 
 
 
419
 
420
- # Correlation analysis
421
  if len(numeric_cols) > 1:
422
  corr_matrix = df[numeric_cols].corr()
423
  # Find strong correlations
@@ -425,11 +433,11 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
425
  for i in range(len(corr_matrix.columns)):
426
  for j in range(i+1, len(corr_matrix.columns)):
427
  corr_val = corr_matrix.iloc[i, j]
428
- if abs(corr_val) > 0.7: # Strong correlation threshold
429
  strong_corrs.append({
430
  "var1": corr_matrix.columns[i],
431
  "var2": corr_matrix.columns[j],
432
- "correlation": corr_val
433
  })
434
  profile["correlations"] = {"strong_correlations": strong_corrs}
435
 
@@ -438,16 +446,22 @@ def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
438
  if len(categorical_cols) > 0:
439
  profile["categorical_analysis"] = {}
440
  for col in categorical_cols:
 
441
  profile["categorical_analysis"][col] = {
442
- "unique_count": df[col].nunique(),
443
- "top_values": df[col].value_counts().head(5).to_dict()
444
  }
445
 
446
- # Missing data analysis
447
  missing_data = df.isnull().sum()
 
 
 
 
 
448
  profile["missing_data"] = {
449
- "columns_with_missing": missing_data[missing_data > 0].to_dict(),
450
- "total_missing_percentage": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
451
  }
452
 
453
  return profile
@@ -472,11 +486,11 @@ def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
472
  for col2 in numeric_cols:
473
  if col1 != col2:
474
  correlation = df[col1].corr(df[col2])
475
- if abs(correlation) > 0.5: # Moderate to strong correlation
476
  relationships["key_relationships"].append({
477
  "variable1": col1,
478
  "variable2": col2,
479
- "relationship_strength": correlation,
480
  "relationship_type": "positive" if correlation > 0 else "negative"
481
  })
482
 
@@ -489,8 +503,8 @@ def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
489
  relationships["patterns"].append({
490
  "column": col,
491
  "pattern_type": "categorical_distribution",
492
- "dominant_category": value_counts.index[0],
493
- "dominance_percentage": (value_counts.iloc[0] / len(df)) * 100
494
  })
495
 
496
  return relationships
@@ -526,7 +540,7 @@ def analyze_temporal_patterns(df: pd.DataFrame) -> Dict[str, Any]:
526
  "start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
527
  "end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
528
  },
529
- "time_span_days": (df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days,
530
  "frequency": detect_temporal_frequency(df_temp[primary_date_col])
531
  }
532
 
@@ -544,16 +558,16 @@ def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
544
  "data_consistency": {}
545
  }
546
 
547
- # Completeness assessment
548
- completeness = (1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
549
  quality_metrics["data_completeness"] = completeness
550
 
551
  # Identify quality issues
552
  if completeness < 95:
553
  quality_metrics["quality_issues"].append("Missing data detected")
554
 
555
- # Check for duplicates
556
- duplicate_rows = df.duplicated().sum()
557
  if duplicate_rows > 0:
558
  quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
559
 
@@ -563,11 +577,11 @@ def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
563
  if df[col].str.isnumeric().any() and not df[col].str.isnumeric().all():
564
  quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
565
 
566
- # Calculate overall quality score
567
- base_score = 100
568
  base_score -= (100 - completeness) * 0.5 # Penalize missing data
569
  base_score -= len(quality_metrics["quality_issues"]) * 5 # Penalize each quality issue
570
- quality_metrics["overall_quality_score"] = max(0, base_score)
571
 
572
  return quality_metrics
573
 
@@ -839,6 +853,235 @@ def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_
839
  return {"raw_md": md, "chartUrls": chart_urls}
840
 
841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
843
  """
844
  Generates a basic fallback report when enhanced generation fails.
 
310
  """
311
  logging.info("Performing autonomous data analysis...")
312
 
313
+ # Basic data profiling with JSON-safe types
314
  basic_info = {
315
  "shape": df.shape,
316
  "columns": list(df.columns),
317
+ "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
318
  "filename": filename,
319
  "user_context": user_ctx
320
  }
 
388
 
389
  return {
390
  "primary_domain": primary_domain,
391
+ "domain_confidence": int(domain_scores.get(primary_domain, 0)),
392
+ "domain_scores": {k: int(v) for k, v in domain_scores.items()},
393
  "data_characteristics": {
394
+ "numeric_ratio": float(numeric_ratio),
395
+ "categorical_ratio": float(categorical_ratio),
396
  "is_time_series": detect_time_series(df),
397
  "is_transactional": detect_transactional_data(df),
398
  "is_experimental": detect_experimental_data(df)
 
412
  "missing_data": {}
413
  }
414
 
415
+ # Summary statistics for numeric columns with JSON-safe conversion
416
  numeric_cols = df.select_dtypes(include=[np.number]).columns
417
  if len(numeric_cols) > 0:
418
+ desc_stats = df[numeric_cols].describe()
419
+ # Convert to JSON-safe format
420
+ profile["summary_stats"] = {
421
+ col: {
422
+ stat: float(val) if pd.notna(val) else None
423
+ for stat, val in desc_stats[col].items()
424
+ }
425
+ for col in desc_stats.columns
426
+ }
427
 
428
+ # Correlation analysis with JSON-safe conversion
429
  if len(numeric_cols) > 1:
430
  corr_matrix = df[numeric_cols].corr()
431
  # Find strong correlations
 
433
  for i in range(len(corr_matrix.columns)):
434
  for j in range(i+1, len(corr_matrix.columns)):
435
  corr_val = corr_matrix.iloc[i, j]
436
+ if abs(corr_val) > 0.7 and pd.notna(corr_val): # Strong correlation threshold
437
  strong_corrs.append({
438
  "var1": corr_matrix.columns[i],
439
  "var2": corr_matrix.columns[j],
440
+ "correlation": float(corr_val)
441
  })
442
  profile["correlations"] = {"strong_correlations": strong_corrs}
443
 
 
446
  if len(categorical_cols) > 0:
447
  profile["categorical_analysis"] = {}
448
  for col in categorical_cols:
449
+ value_counts = df[col].value_counts().head(5)
450
  profile["categorical_analysis"][col] = {
451
+ "unique_count": int(df[col].nunique()),
452
+ "top_values": {str(k): int(v) for k, v in value_counts.items()}
453
  }
454
 
455
+ # Missing data analysis with JSON-safe conversion
456
  missing_data = df.isnull().sum()
457
+ missing_dict = {}
458
+ for col, missing_count in missing_data.items():
459
+ if missing_count > 0:
460
+ missing_dict[col] = int(missing_count)
461
+
462
  profile["missing_data"] = {
463
+ "columns_with_missing": missing_dict,
464
+ "total_missing_percentage": float((df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100)
465
  }
466
 
467
  return profile
 
486
  for col2 in numeric_cols:
487
  if col1 != col2:
488
  correlation = df[col1].corr(df[col2])
489
+ if abs(correlation) > 0.5 and pd.notna(correlation): # Moderate to strong correlation
490
  relationships["key_relationships"].append({
491
  "variable1": col1,
492
  "variable2": col2,
493
+ "relationship_strength": float(correlation),
494
  "relationship_type": "positive" if correlation > 0 else "negative"
495
  })
496
 
 
503
  relationships["patterns"].append({
504
  "column": col,
505
  "pattern_type": "categorical_distribution",
506
+ "dominant_category": str(value_counts.index[0]),
507
+ "dominance_percentage": float((value_counts.iloc[0] / len(df)) * 100)
508
  })
509
 
510
  return relationships
 
540
  "start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
541
  "end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
542
  },
543
+ "time_span_days": int((df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days),
544
  "frequency": detect_temporal_frequency(df_temp[primary_date_col])
545
  }
546
 
 
558
  "data_consistency": {}
559
  }
560
 
561
+ # Completeness assessment with JSON-safe conversion
562
+ completeness = float((1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100)
563
  quality_metrics["data_completeness"] = completeness
564
 
565
  # Identify quality issues
566
  if completeness < 95:
567
  quality_metrics["quality_issues"].append("Missing data detected")
568
 
569
+ # Check for duplicates with JSON-safe conversion
570
+ duplicate_rows = int(df.duplicated().sum())
571
  if duplicate_rows > 0:
572
  quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
573
 
 
577
  if df[col].str.isnumeric().any() and not df[col].str.isnumeric().all():
578
  quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
579
 
580
+ # Calculate overall quality score with JSON-safe conversion
581
+ base_score = 100.0
582
  base_score -= (100 - completeness) * 0.5 # Penalize missing data
583
  base_score -= len(quality_metrics["quality_issues"]) * 5 # Penalize each quality issue
584
+ quality_metrics["overall_quality_score"] = float(max(0, base_score))
585
 
586
  return quality_metrics
587
 
 
853
  return {"raw_md": md, "chartUrls": chart_urls}
854
 
855
 
856
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
    """
    Generates a basic fallback report when enhanced generation fails.

    Args:
        autonomous_context: Analysis context; expected to contain
            ``basic_info`` (with a ``shape`` tuple) and ``domain``
            (with ``primary_domain``).

    Returns:
        Markdown report string with one embedded chart tag.
    """
    # This function is the last-resort error path, so it must never raise:
    # fall back to neutral placeholders when expected keys are absent
    # instead of propagating a KeyError out of the error handler.
    basic_info = autonomous_context.get("basic_info", {})
    shape = basic_info.get("shape", ("unknown", "unknown"))
    domain = autonomous_context.get("domain", {}).get("primary_domain", "general")

    return f"""
# What This Data Reveals

Looking at this {domain} dataset with {shape[0]} records, there are several key insights worth highlighting.

## The Numbers Tell a Story

This dataset contains {shape[1]} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.

<generate_chart: "bar | Data overview showing key metrics">

## What You Should Know

The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicate multiple analytical opportunities.

## Next Steps

Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.

*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
"""
884
+ # Removed - no longer needed since we're letting AI decide everything organically
885
+
886
+
887
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
    """
    Generates charts autonomously based on the report content and data characteristics.

    Args:
        llm: Language-model client handed to ChartGenerator to produce chart specs.
        df: Dataset the charts are rendered from.
        report_md: Markdown report containing ``<generate_chart: ...>`` tags.
        uid: User id used to build the storage blob path.
        project_id: Project id used to build the storage blob path.
        bucket: Cloud-storage bucket the rendered PNGs are uploaded to.

    Returns:
        Mapping of sanitized chart description -> public chart URL.

    NOTE(review): the tag replacements applied to ``report_md`` below rebind a
    local variable only; the sanitized markdown is never returned, so callers
    do not see the rewritten tags — confirm whether that is intended.
    """
    # Extract chart descriptions from the enhanced report
    chart_descs = extract_chart_tags(report_md)[:MAX_CHARTS]
    chart_urls = {}

    if not chart_descs:
        # If no charts specified, generate intelligent defaults
        chart_descs = generate_intelligent_chart_suggestions(df, llm)

    chart_generator = ChartGenerator(llm, df)

    for desc in chart_descs:
        try:
            # Create a safe key for Firebase
            safe_desc = sanitize_for_firebase_key(desc)

            # Replace chart tags in markdown (both quoted and unquoted forms)
            report_md = report_md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
            report_md = report_md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')

            # Generate chart; delete=False so the file survives the with-block
            # for the upload, then it is removed manually in the finally.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                img_path = Path(temp_file.name)
                try:
                    chart_spec = chart_generator.generate_chart_spec(desc)
                    if execute_chart_spec(chart_spec, df, img_path):
                        blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                        blob = bucket.blob(blob_name)
                        blob.upload_from_filename(str(img_path))

                        chart_urls[safe_desc] = blob.public_url
                        logging.info(f"Generated autonomous chart: {safe_desc}")
                finally:
                    if os.path.exists(img_path):
                        os.unlink(img_path)

        except Exception as e:
            # Chart generation is best-effort: log and continue with the next tag.
            logging.error(f"Failed to generate chart '{desc}': {str(e)}")
            continue

    return chart_urls
931
+
932
+
933
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
    """
    Generates intelligent chart suggestions based on data characteristics.

    Inspects the dataset's column types and proposes up to MAX_CHARTS chart
    descriptions in "type | title | purpose" form.
    """
    numeric = df.select_dtypes(include=[np.number]).columns
    categorical = df.select_dtypes(include=['object']).columns

    ideas = []

    # Temporal data -> trend line.
    if detect_time_series(df):
        ideas.append("line | Time series trend analysis | Show temporal patterns")

    # First numeric column -> distribution histogram.
    if len(numeric) > 0:
        ideas.append(f"hist | Distribution of {numeric[0]} | Understand data distribution")

    # Several numeric columns -> relationship scatter.
    if len(numeric) > 1:
        ideas.append("scatter | Correlation analysis | Identify relationships between variables")

    # First categorical column -> breakdown bar chart.
    if len(categorical) > 0:
        ideas.append(f"bar | {categorical[0]} breakdown | Show categorical distribution")

    return ideas[:MAX_CHARTS]
961
+
962
+
963
+ # Helper functions (preserve existing functionality)
964
def detect_time_series(df: pd.DataFrame) -> bool:
    """Detect if dataset contains time series data.

    A column counts as temporal when its name mentions date/time, it already
    has a datetime dtype, or its string values parse as dates.
    """
    for col in df.columns:
        # Name-based hint first: cheap and matches the common convention.
        if 'date' in col.lower() or 'time' in col.lower():
            return True
        # Columns already parsed as datetimes.
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            return True
        # Only try to parse object (string-like) columns: pd.to_datetime on a
        # plain numeric column "succeeds" by treating the numbers as epoch
        # timestamps, which produced false positives in the original version.
        if df[col].dtype == object:
            try:
                pd.to_datetime(df[col])
                return True
            except (ValueError, TypeError):
                continue
    return False
975
+
976
+
977
def detect_transactional_data(df: pd.DataFrame) -> bool:
    """Detect if dataset contains transactional data.

    A dataset is considered transactional when any column name contains one
    of the commerce-related keywords below (case-insensitive).
    """
    keywords = ('transaction', 'payment', 'order', 'invoice', 'amount', 'quantity')
    for column in df.columns:
        name = column.lower()
        if any(keyword in name for keyword in keywords):
            return True
    return False
982
+
983
+
984
def detect_experimental_data(df: pd.DataFrame) -> bool:
    """Detect if dataset contains experimental data.

    Looks for study-design vocabulary (test/trial/treatment/...) anywhere in
    the column names, case-insensitively.
    """
    keywords = ('test', 'experiment', 'trial', 'group', 'treatment', 'control')
    for column in df.columns:
        name = column.lower()
        if any(keyword in name for keyword in keywords):
            return True
    return False
989
+
990
+
991
def detect_temporal_frequency(date_series: pd.Series) -> str:
    """Detect the frequency of temporal data.

    Returns one of "daily", "weekly", "monthly", "irregular", or
    "insufficient_data" when fewer than two timestamps are available.
    """
    if len(date_series) < 2:
        return "insufficient_data"

    # Median gap between consecutive (sorted) timestamps, so a few outliers
    # do not distort the classification.
    gap = date_series.sort_values().diff().dropna().median()

    for limit_days, label in ((1, "daily"), (7, "weekly"), (31, "monthly")):
        if gap <= pd.Timedelta(days=limit_days):
            return label
    return "irregular"
1008
+
1009
+
1010
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
    """Determine the complexity level of analysis required.

    Scores one point per complexity signal (large row/column counts, many
    numeric or categorical columns, demanding domain) and maps the total to
    "low" / "medium" / "high".
    """
    numeric_count = len(df.select_dtypes(include=[np.number]).columns)
    categorical_count = len(df.select_dtypes(include=['object']).columns)

    signals = [
        len(df) > 10000,                                                        # large dataset
        len(df.columns) > 20,                                                   # wide dataset
        numeric_count > 5,                                                      # many numeric variables
        categorical_count > 5,                                                  # many categorical variables
        domain_analysis["primary_domain"] in ["scientific", "financial"],       # demanding domain
    ]
    score = sum(signals)

    if score >= 3:
        return "high"
    if score >= 2:
        return "medium"
    return "low"
1036
+
1037
+
1038
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
    """
    Fallback to original report generation logic if enhanced version fails.

    Args:
        df: Dataset to analyze.
        llm: Language-model client with an ``invoke`` method returning an
            object exposing ``.content``.
        ctx: Free-form user context included in the prompt.
        uid: User id used to build the chart storage path.
        project_id: Project id used to build the chart storage path.
        bucket: Cloud-storage bucket that receives rendered chart PNGs.

    Returns:
        Dict with ``raw_md`` (report markdown, chart tags rewritten to their
        sanitized form) and ``chartUrls`` (sanitized description -> public URL).
    """
    logging.info("Using fallback report generation")

    # Original logic preserved
    ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
    enhanced_ctx = enhance_data_context(df, ctx_dict)

    report_prompt = f"""
    You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
    **Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
    **Instructions:**
    1. **Executive Summary**: Start with a high-level summary of key findings.
    2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
    3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
    Valid chart types: bar, pie, line, scatter, hist.
    Generate insights that would be valuable to C-level executives.
    """

    md = llm.invoke(report_prompt).content
    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)

    for desc in chart_descs:
        # Firebase keys cannot contain certain characters; sanitize and keep
        # the markdown tags in sync with the sanitized key (both tag forms).
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')

        # delete=False so the file survives the with-block for the upload;
        # it is removed manually in the finally clause below.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
            try:
                chart_spec = chart_generator.generate_chart_spec(desc)
                if execute_chart_spec(chart_spec, df, img_path):
                    blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(img_path))
                    chart_urls[safe_desc] = blob.public_url
            finally:
                if os.path.exists(img_path):
                    os.unlink(img_path)

    return {"raw_md": md, "chartUrls": chart_urls}
1083
+
1084
+
1085
  def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
1086
  """
1087
  Generates a basic fallback report when enhanced generation fails.