rairo committed on
Commit
e3aa5e7
·
verified ·
1 Parent(s): 78debf6

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +64 -17
sozo_gen.py CHANGED
@@ -29,7 +29,7 @@ import requests
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
30
  FPS, WIDTH, HEIGHT = 24, 1280, 720
31
  MAX_CHARTS, VIDEO_SCENES = 5, 5
32
- MAX_CONTEXT_TOKENS = 250000 # Set max token limit for full dataset context
33
 
34
  # --- API Initialization ---
35
  API_KEY = os.getenv("GOOGLE_API_KEY")
@@ -38,7 +38,7 @@ if not API_KEY:
38
 
39
  PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
40
 
41
- # --- Helper Functions ---
42
  def load_dataframe_safely(buf, name: str):
43
  ext = Path(name).suffix.lower()
44
  df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
@@ -147,7 +147,7 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
147
  temp_dl_path.unlink()
148
  return None
149
 
150
- # --- Chart Generation System ---
151
  class ChartSpecification:
152
  def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
153
  self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
@@ -233,7 +233,7 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
233
  return df[numeric_cols].corr()
234
  return df[spec.x_col]
235
 
236
- # --- Animation & Video Generation ---
237
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
238
  plot_data = prepare_plot_data(spec, df)
239
  frames = max(10, int(dur * fps))
@@ -368,6 +368,42 @@ def sanitize_for_firebase_key(text: str) -> str:
368
  text = text.replace(char, '_')
369
  return text
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
372
  """Creates a detailed summary of the dataframe for the AI."""
373
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
@@ -400,41 +436,52 @@ def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
400
  return json.loads(json.dumps(context, default=str))
401
 
402
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
403
- logging.info(f"Generating report draft for project {project_id}")
404
  df = load_dataframe_safely(buf, name)
405
- llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
406
 
 
407
  data_context_str = ""
408
  context_for_charts = {}
409
  try:
410
  df_json = df.to_json(orient='records')
411
  estimated_tokens = len(df_json) / 4
412
  if estimated_tokens < MAX_CONTEXT_TOKENS:
413
- logging.info(f"Dataset is small enough ({estimated_tokens:.0f} tokens). Using full JSON context.")
414
  data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
415
  context_for_charts = get_augmented_context(df, ctx)
416
  else:
417
- raise ValueError("Dataset too large for full context.")
418
  except Exception as e:
419
- logging.warning(f"Could not use full JSON context ({e}). Falling back to augmented summary.")
420
  augmented_context = get_augmented_context(df, ctx)
421
  data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
422
  context_for_charts = augmented_context
423
 
 
 
 
 
424
  report_prompt = f"""
425
- You are an expert data analyst and business intelligence storyteller. Your mission is to analyze the provided data context and write a comprehensive, executive-level report in Markdown format.
426
 
427
  **Data Context:**
428
  {data_context_str}
429
 
430
- **Critical Instructions:**
431
- 1. **Data Grounding:** Your entire analysis and narrative **must strictly** use the column names and data provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list. This is the most important rule.
432
- 2. **Report Goal:** Create a well-structured, professional report in Markdown that tells a compelling story from the data. The structure of the report is entirely up to you, but it should be logical and easy to follow.
433
- 3. **Visual Support:** Wherever a key finding, trend, or significant point is made in your narrative, you **must** support it with a chart tag using the format: `<generate_chart: "chart_type | a specific, compelling description">`.
434
- 4. **Chart Tag Grounding:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
435
- 5. **Available Chart Types:** `bar, pie, line, scatter, hist, heatmap, area, bubble`.
 
 
 
 
 
 
436
 
437
- Now, generate the complete Markdown report.
438
  """
439
 
440
  md = llm.invoke(report_prompt).content
 
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
30
  FPS, WIDTH, HEIGHT = 24, 1280, 720
31
  MAX_CHARTS, VIDEO_SCENES = 5, 5
32
+ MAX_CONTEXT_TOKENS = 500000
33
 
34
  # --- API Initialization ---
35
  API_KEY = os.getenv("GOOGLE_API_KEY")
 
38
 
39
  PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
40
 
41
+ # --- Helper Functions (Stable) ---
42
  def load_dataframe_safely(buf, name: str):
43
  ext = Path(name).suffix.lower()
44
  df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
 
147
  temp_dl_path.unlink()
148
  return None
149
 
150
+ # --- Chart Generation System (Stable) ---
151
  class ChartSpecification:
152
  def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
153
  self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
 
233
  return df[numeric_cols].corr()
234
  return df[spec.x_col]
235
 
236
+ # --- Animation & Video Generation (Stable) ---
237
  def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
238
  plot_data = prepare_plot_data(spec, df)
239
  frames = max(10, int(dur * fps))
 
368
  text = text.replace(char, '_')
369
  return text
370
 
371
+ # NEW: Intelligence functions to guide the storyteller AI
372
def analyze_data_intelligence(df: pd.DataFrame) -> Dict:
    """Analyzes the dataset to find key characteristics and opportunities for storytelling.

    Args:
        df: The dataframe to inspect. It is only read, never modified.

    Returns:
        A dict with four keys:
          - "insight_opportunities": list of short phrases naming the angles
            the data supports (temporal trends, correlations, segmentation,
            missing data), in that fixed order.
          - "is_timeseries": True when any column label contains "date" or
            "time" (case-insensitive) or any column has a datetime dtype.
          - "has_correlations": True when there is more than one numeric column.
          - "has_segments": True when there is at least one non-numeric and at
            least one numeric column.
    """
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()

    # Column labels are not guaranteed to be strings (e.g. numeric or datetime
    # headers), so coerce with str() before the substring test instead of
    # calling .lower() on the raw label.
    labels = [str(col).lower() for col in df.columns]
    named_like_time = any('date' in label or 'time' in label for label in labels)
    # Columns already parsed as datetimes are temporal regardless of their name.
    typed_as_time = any(pd.api.types.is_datetime64_any_dtype(dtype) for dtype in df.dtypes)
    is_timeseries = named_like_time or typed_as_time

    # Compute each flag once so the opportunity list and the returned flags
    # cannot drift apart.
    has_correlations = len(numeric_cols) > 1
    has_segments = bool(categorical_cols and numeric_cols)
    has_missing = bool(df.isnull().sum().sum() > 0)

    opportunities = []
    if is_timeseries:
        opportunities.append("temporal trends")
    if has_correlations:
        opportunities.append("correlations between metrics")
    if has_segments:
        opportunities.append("segmentation by category")
    if has_missing:
        opportunities.append("impact of missing data")

    return {
        "insight_opportunities": opportunities,
        "is_timeseries": is_timeseries,
        "has_correlations": has_correlations,
        "has_segments": has_segments,
    }
395
+
396
def generate_visualization_strategy(intelligence: Dict) -> str:
    """Generates dynamic advice on which charts to use.

    Args:
        intelligence: Dict of boolean flags. The keys read are
            "is_timeseries", "has_correlations" and "has_segments"; a missing
            key is treated as False rather than raising KeyError.

    Returns:
        A single string: a general opening sentence followed by one sentence
        of chart advice for each flag that is truthy, in the order listed
        above. Every sentence ends with a trailing space.
    """
    # Ordered (flag, advice) pairs; the order here is the order of the output.
    advice_by_flag = (
        ("is_timeseries", "Use 'line' or 'area' charts to explore temporal trends. "),
        ("has_correlations", "Use 'scatter' or 'heatmap' charts to reveal correlations. "),
        ("has_segments", "Use 'bar' or 'pie' charts to compare segments. "),
    )
    opening = "Vary your visualizations to keep the report engaging. "
    return opening + "".join(
        advice for flag, advice in advice_by_flag if intelligence.get(flag)
    )
406
+
407
  def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
408
  """Creates a detailed summary of the dataframe for the AI."""
409
  numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
 
436
  return json.loads(json.dumps(context, default=str))
437
 
438
  def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
439
+ logging.info(f"Generating persona-driven report draft for project {project_id}")
440
  df = load_dataframe_safely(buf, name)
441
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
442
 
443
+ # --- Try/Fallback Context Strategy ---
444
  data_context_str = ""
445
  context_for_charts = {}
446
  try:
447
  df_json = df.to_json(orient='records')
448
  estimated_tokens = len(df_json) / 4
449
  if estimated_tokens < MAX_CONTEXT_TOKENS:
450
+ logging.info(f"Using full JSON context.")
451
  data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
452
  context_for_charts = get_augmented_context(df, ctx)
453
  else:
454
+ raise ValueError("Dataset too large.")
455
  except Exception as e:
456
+ logging.warning(f"Falling back to augmented summary context: {e}")
457
  augmented_context = get_augmented_context(df, ctx)
458
  data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
459
  context_for_charts = augmented_context
460
 
461
+ # --- Persona-Driven Prompting ---
462
+ intelligence = analyze_data_intelligence(df)
463
+ viz_strategy = generate_visualization_strategy(intelligence)
464
+
465
  report_prompt = f"""
466
+ You are an elite data storyteller and business intelligence expert. Your mission is to uncover the compelling, hidden narrative in this dataset and present it as a captivating story in Markdown format that drives action.
467
 
468
  **Data Context:**
469
  {data_context_str}
470
 
471
+ **Intelligence Analysis:**
472
+ - The most interesting parts of this story may lie in the following areas: {', '.join(intelligence['insight_opportunities'])}.
473
+ - Weave these threads into your core narrative.
474
+
475
+ **Visualization Strategy:**
476
+ - {viz_strategy}
477
+ - Available Chart Types: `bar, pie, line, scatter, hist, heatmap, area, bubble`.
478
+
479
+ **Your Grounding Rules (Most Important):**
480
+ 1. **Strict Accuracy:** Your entire analysis and narrative **must strictly** use the column names provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list.
481
+ 2. **Chart Support:** Wherever a key finding is made, you **must** support it with a chart tag: `<generate_chart: "chart_type | a specific, compelling description">`.
482
+ 3. **Chart Accuracy:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
483
 
484
+ Now, begin your report. Let the data's story unfold naturally.
485
  """
486
 
487
  md = llm.invoke(report_prompt).content