Update sozo_gen.py
Browse files- sozo_gen.py +64 -17
sozo_gen.py
CHANGED
|
@@ -29,7 +29,7 @@ import requests
|
|
| 29 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
| 30 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 31 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 32 |
-
MAX_CONTEXT_TOKENS =
|
| 33 |
|
| 34 |
# --- API Initialization ---
|
| 35 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
|
@@ -38,7 +38,7 @@ if not API_KEY:
|
|
| 38 |
|
| 39 |
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
|
| 40 |
|
| 41 |
-
# --- Helper Functions ---
|
| 42 |
def load_dataframe_safely(buf, name: str):
|
| 43 |
ext = Path(name).suffix.lower()
|
| 44 |
df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
|
|
@@ -147,7 +147,7 @@ def search_and_download_pexels_video(query: str, duration: float, out_path: Path
|
|
| 147 |
temp_dl_path.unlink()
|
| 148 |
return None
|
| 149 |
|
| 150 |
-
# --- Chart Generation System ---
|
| 151 |
class ChartSpecification:
|
| 152 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 153 |
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
|
|
@@ -233,7 +233,7 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
|
| 233 |
return df[numeric_cols].corr()
|
| 234 |
return df[spec.x_col]
|
| 235 |
|
| 236 |
-
# --- Animation & Video Generation ---
|
| 237 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 238 |
plot_data = prepare_plot_data(spec, df)
|
| 239 |
frames = max(10, int(dur * fps))
|
|
@@ -368,6 +368,42 @@ def sanitize_for_firebase_key(text: str) -> str:
|
|
| 368 |
text = text.replace(char, '_')
|
| 369 |
return text
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
|
| 372 |
"""Creates a detailed summary of the dataframe for the AI."""
|
| 373 |
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
|
@@ -400,41 +436,52 @@ def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
|
|
| 400 |
return json.loads(json.dumps(context, default=str))
|
| 401 |
|
| 402 |
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
|
| 403 |
-
logging.info(f"Generating report draft for project {project_id}")
|
| 404 |
df = load_dataframe_safely(buf, name)
|
| 405 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.
|
| 406 |
|
|
|
|
| 407 |
data_context_str = ""
|
| 408 |
context_for_charts = {}
|
| 409 |
try:
|
| 410 |
df_json = df.to_json(orient='records')
|
| 411 |
estimated_tokens = len(df_json) / 4
|
| 412 |
if estimated_tokens < MAX_CONTEXT_TOKENS:
|
| 413 |
-
logging.info(f"
|
| 414 |
data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
|
| 415 |
context_for_charts = get_augmented_context(df, ctx)
|
| 416 |
else:
|
| 417 |
-
raise ValueError("Dataset too large
|
| 418 |
except Exception as e:
|
| 419 |
-
logging.warning(f"
|
| 420 |
augmented_context = get_augmented_context(df, ctx)
|
| 421 |
data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
|
| 422 |
context_for_charts = augmented_context
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
report_prompt = f"""
|
| 425 |
-
You are an
|
| 426 |
|
| 427 |
**Data Context:**
|
| 428 |
{data_context_str}
|
| 429 |
|
| 430 |
-
**
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
-
Now,
|
| 438 |
"""
|
| 439 |
|
| 440 |
md = llm.invoke(report_prompt).content
|
|
|
|
| 29 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
|
| 30 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 31 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 32 |
+
MAX_CONTEXT_TOKENS = 500000
|
| 33 |
|
| 34 |
# --- API Initialization ---
|
| 35 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
|
|
|
| 38 |
|
| 39 |
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
|
| 40 |
|
| 41 |
+
# --- Helper Functions (Stable) ---
|
| 42 |
def load_dataframe_safely(buf, name: str):
|
| 43 |
ext = Path(name).suffix.lower()
|
| 44 |
df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
|
|
|
|
| 147 |
temp_dl_path.unlink()
|
| 148 |
return None
|
| 149 |
|
| 150 |
+
# --- Chart Generation System (Stable) ---
|
| 151 |
class ChartSpecification:
|
| 152 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 153 |
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
|
|
|
|
| 233 |
return df[numeric_cols].corr()
|
| 234 |
return df[spec.x_col]
|
| 235 |
|
| 236 |
+
# --- Animation & Video Generation (Stable) ---
|
| 237 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 238 |
plot_data = prepare_plot_data(spec, df)
|
| 239 |
frames = max(10, int(dur * fps))
|
|
|
|
| 368 |
text = text.replace(char, '_')
|
| 369 |
return text
|
| 370 |
|
| 371 |
+
# NEW: Intelligence functions to guide the storyteller AI
|
| 372 |
+
def analyze_data_intelligence(df: pd.DataFrame) -> Dict:
    """Analyzes the dataset to find key characteristics and opportunities for storytelling.

    Args:
        df: The dataset to profile.

    Returns:
        Dict with keys:
            insight_opportunities: list of narrative angles worth exploring.
            is_timeseries: True if a column name contains 'date'/'time', or
                a column carries a datetime dtype.
            has_correlations: True if at least two numeric columns exist.
            has_segments: True if both non-numeric and numeric columns exist.
    """
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()

    # str(col) guards against non-string column labels (e.g. integer headers
    # from a headerless CSV), which would crash .lower(). Also treat genuine
    # datetime dtypes as time series even when the name gives no hint.
    is_timeseries = any(
        'date' in str(col).lower() or 'time' in str(col).lower()
        for col in df.columns
    ) or any(pd.api.types.is_datetime64_any_dtype(dt) for dt in df.dtypes)

    opportunities = []
    if is_timeseries:
        opportunities.append("temporal trends")
    if len(numeric_cols) > 1:
        opportunities.append("correlations between metrics")
    if len(categorical_cols) > 0 and len(numeric_cols) > 0:
        opportunities.append("segmentation by category")
    if df.isnull().sum().sum() > 0:
        opportunities.append("impact of missing data")

    return {
        "insight_opportunities": opportunities,
        "is_timeseries": is_timeseries,
        "has_correlations": len(numeric_cols) > 1,
        "has_segments": len(categorical_cols) > 0 and len(numeric_cols) > 0,
    }
|
| 395 |
+
|
| 396 |
+
def generate_visualization_strategy(intelligence: Dict) -> str:
    """Generates dynamic advice on which charts to use.

    Args:
        intelligence: Output of analyze_data_intelligence; the boolean flags
            'is_timeseries', 'has_correlations' and 'has_segments' select
            which chart hints are included.

    Returns:
        A single advice string for the report prompt.
    """
    # Each hint ends with a trailing space, so plain concatenation via join
    # reproduces the exact sentence spacing of the final string.
    hints = ["Vary your visualizations to keep the report engaging. "]
    if intelligence["is_timeseries"]:
        hints.append("Use 'line' or 'area' charts to explore temporal trends. ")
    if intelligence["has_correlations"]:
        hints.append("Use 'scatter' or 'heatmap' charts to reveal correlations. ")
    if intelligence["has_segments"]:
        hints.append("Use 'bar' or 'pie' charts to compare segments. ")
    return "".join(hints)
|
| 406 |
+
|
| 407 |
def get_augmented_context(df: pd.DataFrame, user_ctx: str) -> Dict:
|
| 408 |
"""Creates a detailed summary of the dataframe for the AI."""
|
| 409 |
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
|
|
|
| 436 |
return json.loads(json.dumps(context, default=str))
|
| 437 |
|
| 438 |
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
|
| 439 |
+
logging.info(f"Generating persona-driven report draft for project {project_id}")
|
| 440 |
df = load_dataframe_safely(buf, name)
|
| 441 |
+
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
|
| 442 |
|
| 443 |
+
# --- Try/Fallback Context Strategy ---
|
| 444 |
data_context_str = ""
|
| 445 |
context_for_charts = {}
|
| 446 |
try:
|
| 447 |
df_json = df.to_json(orient='records')
|
| 448 |
estimated_tokens = len(df_json) / 4
|
| 449 |
if estimated_tokens < MAX_CONTEXT_TOKENS:
|
| 450 |
+
logging.info(f"Using full JSON context.")
|
| 451 |
data_context_str = f"Here is the full dataset in JSON format:\n{df_json}"
|
| 452 |
context_for_charts = get_augmented_context(df, ctx)
|
| 453 |
else:
|
| 454 |
+
raise ValueError("Dataset too large.")
|
| 455 |
except Exception as e:
|
| 456 |
+
logging.warning(f"Falling back to augmented summary context: {e}")
|
| 457 |
augmented_context = get_augmented_context(df, ctx)
|
| 458 |
data_context_str = f"The full dataset is too large to display. Here is a detailed summary:\n{json.dumps(augmented_context, indent=2)}"
|
| 459 |
context_for_charts = augmented_context
|
| 460 |
|
| 461 |
+
# --- Persona-Driven Prompting ---
|
| 462 |
+
intelligence = analyze_data_intelligence(df)
|
| 463 |
+
viz_strategy = generate_visualization_strategy(intelligence)
|
| 464 |
+
|
| 465 |
report_prompt = f"""
|
| 466 |
+
You are an elite data storyteller and business intelligence expert. Your mission is to uncover the compelling, hidden narrative in this dataset and present it as a captivating story in Markdown format that drives action.
|
| 467 |
|
| 468 |
**Data Context:**
|
| 469 |
{data_context_str}
|
| 470 |
|
| 471 |
+
**Intelligence Analysis:**
|
| 472 |
+
- The most interesting parts of this story may lie in the following areas: {', '.join(intelligence['insight_opportunities'])}.
|
| 473 |
+
- Weave these threads into your core narrative.
|
| 474 |
+
|
| 475 |
+
**Visualization Strategy:**
|
| 476 |
+
- {viz_strategy}
|
| 477 |
+
- Available Chart Types: `bar, pie, line, scatter, hist, heatmap, area, bubble`.
|
| 478 |
+
|
| 479 |
+
**Your Grounding Rules (Most Important):**
|
| 480 |
+
1. **Strict Accuracy:** Your entire analysis and narrative **must strictly** use the column names provided in the 'Data Context' section. Do not invent, modify, or assume any column names that are not on this list.
|
| 481 |
+
2. **Chart Support:** Wherever a key finding is made, you **must** support it with a chart tag: `<generate_chart: "chart_type | a specific, compelling description">`.
|
| 482 |
+
3. **Chart Accuracy:** The column names used in your chart descriptions **must** also be an exact match from the provided data context.
|
| 483 |
|
| 484 |
+
Now, begin your report. Let the data's story unfold naturally.
|
| 485 |
"""
|
| 486 |
|
| 487 |
md = llm.invoke(report_prompt).content
|