rairo commited on
Commit
7ba591e
·
verified ·
1 Parent(s): 5e18bc5

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +25 -3
sozo_gen.py CHANGED
@@ -86,6 +86,15 @@ def clean_narration(txt: str) -> str:
86
 
87
  def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
88
 
 
 
 
 
 
 
 
 
 
89
 
90
  def detect_dataset_domain(df: pd.DataFrame) -> str:
91
  """Analyzes column names to detect the dataset's primary domain."""
@@ -415,20 +424,33 @@ def generate_visualization_strategy(intelligence: Dict) -> str:
415
  return strategy
416
 
417
  def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
 
418
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
419
  categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
 
420
  context = {
421
  "user_context": user_ctx,
422
  "dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
423
  "schema": {"numeric_columns": numeric_cols, "categorical_columns": categorical_cols},
424
  "data_previews": {}
425
  }
 
426
  for col in categorical_cols[:5]:
427
  unique_vals = df[col].unique()
428
- context["data_previews"][col] = {"count": len(unique_vals), "values": unique_vals[:5].tolist()}
 
 
 
 
429
  for col in numeric_cols[:5]:
430
- context["data_previews"][col] = {"mean": df[col].mean(), "min": df[col].min(), "max": df[col].max()}
431
- return json.loads(json.dumps(context, default=str))
 
 
 
 
 
 
432
 
433
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
434
  logging.info(f"Generating guided storyteller report draft for project {project_id}")
 
86
 
87
  def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
88
 
89
+ def _sanitize_for_json(data):
90
+ """Recursively sanitizes a dict/list for JSON compliance."""
91
+ if isinstance(data, dict):
92
+ return {k: _sanitize_for_json(v) for k, v in data.items()}
93
+ if isinstance(data, list):
94
+ return [_sanitize_for_json(i) for i in data]
95
+ if isinstance(data, float) and (math.isnan(data) or math.isinf(data)):
96
+ return None
97
+ return data
98
 
99
  def detect_dataset_domain(df: pd.DataFrame) -> str:
100
  """Analyzes column names to detect the dataset's primary domain."""
 
424
  return strategy
425
 
426
  def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
427
+ """Creates a detailed, JSON-safe summary of the dataframe for the AI."""
428
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
429
  categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
430
+
431
  context = {
432
  "user_context": user_ctx,
433
  "dataset_shape": {"rows": df.shape[0], "columns": df.shape[1]},
434
  "schema": {"numeric_columns": numeric_cols, "categorical_columns": categorical_cols},
435
  "data_previews": {}
436
  }
437
+
438
  for col in categorical_cols[:5]:
439
  unique_vals = df[col].unique()
440
+ context["data_previews"][col] = {
441
+ "count": len(unique_vals),
442
+ "values": unique_vals[:5].tolist()
443
+ }
444
+
445
  for col in numeric_cols[:5]:
446
+ context["data_previews"][col] = {
447
+ "mean": df[col].mean(),
448
+ "min": df[col].min(),
449
+ "max": df[col].max()
450
+ }
451
+
452
+ # Sanitize the entire structure before returning
453
+ return _sanitize_for_json(json.loads(json.dumps(context, default=str)))
454
 
455
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
456
  logging.info(f"Generating guided storyteller report draft for project {project_id}")