# sozo_gen.py
import os
import re
import json
import logging
import uuid
import io
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter
from PIL import Image
import cv2
import inspect
import tempfile
import subprocess
from typing import Dict, List, Tuple, Any
from langchain_google_genai import ChatGoogleGenerativeAI
from google import genai
from google.genai import types
import requests
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
FPS, WIDTH, HEIGHT = 24, 1280, 720
MAX_CHARTS, VIDEO_SCENES = 5, 5
# --- Gemini API Initialization ---
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
raise ValueError("GOOGLE_API_KEY environment variable not set.")
# --- Helper Functions ---
def load_dataframe_safely(buf, name: str):
ext = Path(name).suffix.lower()
df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
df.columns = df.columns.astype(str).str.strip()
df = df.dropna(how="all")
if df.empty or len(df.columns) == 0: raise ValueError("No usable data found")
return df
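# A minimal usage sketch (hypothetical filename): the loader dispatches on the
# extension, so .xlsx/.xls buffers go through read_excel and everything else
# through read_csv:
#     with open("sales.xlsx", "rb") as f:
#         df = load_dataframe_safely(f, "sales.xlsx")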
def deepgram_tts(txt: str, voice_model: str):
DG_KEY = os.getenv("DEEPGRAM_API_KEY")
if not DG_KEY or not txt: return None
txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
try:
r = requests.post("https://api.deepgram.com/v1/speak", params={"model": voice_model}, headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"}, json={"text": txt}, timeout=30)
r.raise_for_status()
return r.content
except Exception as e:
logging.error(f"Deepgram TTS failed: {e}")
return None
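# Usage sketch: synthesize narration and persist it for later ffmpeg muxing.
# "aura-asteria-en" is one of Deepgram's Aura voice models; any other Aura
# voice id should work the same way.
#     audio = deepgram_tts("Revenue grew twelve percent this quarter.", "aura-asteria-en")
#     if audio:
#         Path("narration.mp3").write_bytes(audio)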
def generate_silence_mp3(duration: float, out: Path):
subprocess.run([ "ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono", "-t", f"{duration:.3f}", "-q:a", "9", str(out)], check=True, capture_output=True)
def audio_duration(path: str) -> float:
try:
res = subprocess.run([ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=nw=1:nk=1", path], text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
return float(res.stdout.strip())
except Exception: return 5.0
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
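# Example: both <...> and [...] delimited tags are recognized, and duplicates
# are dropped while preserving order:
#     extract_chart_tags('Intro <generate_chart: "bar | Sales by region"> outro')
#     -> ['bar | Sales by region']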
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
def clean_narration(txt: str) -> str:
txt = TAG_RE.sub("", txt); txt = re_scene.sub("", txt)
phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
for phrase in phrases_to_remove: txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
txt = re.sub(r"\s*\([^)]*\)", "", txt); txt = re.sub(r"[\*#_]", "", txt)
return re.sub(r"\s{2,}", " ", txt).strip()
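# Example: scene markers, chart tags, parentheticals, and markdown emphasis
# are all stripped before the text reaches TTS:
#     clean_narration('Scene 1: Revenue **grew** (see chart) <generate_chart: "bar | x">')
#     -> 'Revenue grew'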
def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
def generate_image_from_prompt(prompt: str) -> Image.Image:
    # The google-genai SDK ("from google import genai") exposes a Client API;
    # the legacy genai.GenerativeModel class used here previously does not
    # exist in this package, so every call raised and fell through to the
    # placeholder. Best-effort sketch against the Client API: image output has
    # to be requested explicitly via response_modalities for this model.
    model_main = "gemini-2.0-flash-exp"
    full_prompt = "A clean business-presentation illustration: " + prompt
    try:
        client = genai.Client(api_key=API_KEY)
        res = client.models.generate_content(
            model=model_main,
            contents=full_prompt,
            config=types.GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
        )
        img_part = next((part for part in res.candidates[0].content.parts if getattr(part, "inline_data", None)), None)
        if img_part:
            return Image.open(io.BytesIO(img_part.inline_data.data)).convert("RGB")
        return placeholder_img()
    except Exception:
        return placeholder_img()
# --- Chart Generation System ---
class ChartSpecification:
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
# NOTE: enhance_data_context, used by ChartGenerator below, is defined in the
# main business-logic section later in this file.
class ChartGenerator:
def __init__(self, llm, df: pd.DataFrame):
self.llm = llm; self.df = df
self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
def generate_chart_spec(self, description: str) -> ChartSpecification:
spec_prompt = f"""
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
**Chart Request:** {description}
**Return a JSON specification with these exact fields:**
{{
"chart_type": "bar|pie|line|scatter|hist", "title": "Professional chart title", "x_col": "column_name_for_x_axis",
"y_col": "column_name_for_y_axis_or_null", "agg_method": "sum|mean|count|max|min|null", "top_n": "number_for_top_n_filtering_or_null"
}}
Return only the JSON specification, no additional text.
"""
try:
response = self.llm.invoke(spec_prompt).content.strip()
            # Strip optional markdown code fences; slicing fixed offsets broke
            # whenever the model omitted the trailing fence.
            response = re.sub(r"^```(?:json)?\s*", "", response)
            response = re.sub(r"\s*```$", "", response)
spec_dict = json.loads(response)
valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
return ChartSpecification(**filtered_dict)
except Exception as e:
logging.error(f"Spec generation failed: {e}. Using fallback.")
return self._create_fallback_spec(description)
    def _create_fallback_spec(self, description: str) -> ChartSpecification:
        # enhanced_ctx may be replaced by callers (e.g. with a chart-safe copy),
        # so recompute the column lists from the DataFrame when they are missing.
        numeric_cols = self.enhanced_ctx.get('numeric_columns') or self.df.select_dtypes(include='number').columns.tolist()
        categorical_cols = self.enhanced_ctx.get('categorical_columns') or self.df.select_dtypes(exclude='number').columns.tolist()
        ctype = "bar"
        for t in ["pie", "line", "scatter", "hist"]:
            if t in description.lower(): ctype = t
        x = categorical_cols[0] if categorical_cols else self.df.columns[0]
        y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
        return ChartSpecification(ctype, description, x, y)
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
try:
plot_data = prepare_plot_data(spec, df)
fig, ax = plt.subplots(figsize=(12, 8)); plt.style.use('default')
if spec.chart_type == "bar": ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.tick_params(axis='x', rotation=45)
elif spec.chart_type == "pie": ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90); ax.axis('equal')
elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
return True
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
if spec.chart_type in ["bar", "pie"]:
if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
return grouped.nlargest(spec.top_n or 10)
elif spec.chart_type == "line": return df.set_index(spec.x_col)[spec.y_col].sort_index()
elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
elif spec.chart_type == "hist": return df[spec.x_col].dropna()
return df[spec.x_col]
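# Shapes returned above, per chart type: bar/pie get an aggregated Series
# (e.g. df.groupby("region")["sales"].sum().nlargest(10) for x_col="region",
# y_col="sales"; column names illustrative), line gets a Series indexed by
# x_col, scatter a two-column DataFrame, and hist the raw x_col values. These
# are exactly the shapes animate_chart and execute_chart_spec consume.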
# --- Animation & Video Generation ---
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
plot_data = prepare_plot_data(spec, df)
frames = max(10, int(dur * fps))
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
plt.tight_layout(pad=3.0)
ctype = spec.chart_type
if ctype == "pie":
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
ax.set_title(spec.title); ax.axis('equal')
def init(): [w.set_alpha(0) for w in wedges]; return wedges
def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
elif ctype == "bar":
bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
def init(): return bars
def update(i):
for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
return bars
elif ctype == "scatter":
scat = ax.scatter([], [], alpha=0.7)
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
def init(): scat.set_offsets(np.empty((0, 2))); return [scat]
def update(i):
k = max(1, int(len(x_full) * (i / (frames - 1))))
scat.set_offsets(plot_data.iloc[:k].values); return [scat]
elif ctype == "hist":
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
def init(): [p.set_alpha(0) for p in patches]; return patches
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
else: # line
line, = ax.plot([], [], lw=2)
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
x_full, y_full = plot_data.index, plot_data.values
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
def init(): line.set_data([], []); return [line]
def update(i):
k = max(2, int(len(x_full) * (i / (frames - 1))))
line.set_data(x_full[:k], y_full[:k]); return [line]
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
    # Save at the figure's own dpi (100) so frames come out at WIDTH x HEIGHT;
    # saving at dpi=144 produced ~1843x1037 clips that break the "-c copy"
    # concat with the 1280x720 image-fade scenes.
    anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=100)
plt.close(fig)
return str(out)
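# Usage sketch (hypothetical columns): render a five-second animated bar chart.
# FFMpegWriter shells out to the ffmpeg binary, so ffmpeg must be on PATH.
#     spec = ChartSpecification("bar", "Sales by Region", "region", "sales")
#     animate_chart(spec, df, dur=5.0, out=Path("/tmp/chart.mp4"))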
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
total_frames = max(1, int(dur * fps))
for i in range(total_frames):
alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
video_writer.write(frame)
video_writer.release()
return str(out)
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
try:
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
chart_generator = ChartGenerator(llm, df)
chart_spec = chart_generator.generate_chart_spec(desc)
return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
            temp_png = Path(temp_png_file.name)
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        if execute_chart_spec(chart_spec, df, temp_png):
            img = cv2.imread(str(temp_png))
            os.unlink(temp_png)
            img_resized = cv2.resize(img, (WIDTH, HEIGHT))
            return animate_image_fade(img_resized, dur, out)
        # Static rendering failed too: drop the unused temp file and fall back
        # to a generated illustration.
        temp_png.unlink(missing_ok=True)
        img = generate_image_from_prompt(f"A professional business chart showing {desc}")
        img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
        return animate_image_fade(img_cv, dur, out)
def concat_media(file_paths: List[str], output_path: Path):
valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
if not valid_paths: raise ValueError("No valid media files to concatenate.")
if len(valid_paths) == 1: import shutil; shutil.copy2(valid_paths[0], str(output_path)); return
list_file = output_path.with_suffix(".txt")
with open(list_file, 'w') as f:
for path in valid_paths: f.write(f"file '{Path(path).resolve()}'\n")
cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
finally:
list_file.unlink(missing_ok=True)
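# The list file written above uses ffmpeg's concat demuxer format, one entry
# per line:
#     file '/tmp/scene_1.mp4'
#     file '/tmp/scene_2.mp4'
# Stream copy ("-c copy") only works when every input shares codec, resolution,
# and frame rate, which is why all scene clips are rendered at WIDTH x HEIGHT
# and FPS.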
# Backward-compatible fix: Override json.dumps to handle non-serializable types
import json as _json
def safe_json_dumps(obj, indent=2, **kwargs):
"""Safely serialize object to JSON, handling non-serializable types."""
def json_serializer(obj):
if isinstance(obj, (bool, int, float, str, type(None))):
return obj
elif isinstance(obj, (list, tuple)):
return [json_serializer(item) for item in obj]
elif isinstance(obj, dict):
return {key: json_serializer(value) for key, value in obj.items()}
else:
# Convert non-serializable objects to string representation
return str(obj)
try:
return _json.dumps(json_serializer(obj), indent=indent, **kwargs)
except Exception as e:
logging.warning(f"JSON serialization failed: {e}")
return str(obj)
# Monkey patch json.dumps to use our safe version
json.dumps = safe_json_dumps
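# Example: objects json can't encode fall back to their string form, so the
# patched json.dumps no longer raises on pandas/numpy values.
#     json.dumps({"when": pd.Timestamp("2024-01-01")})
#     -> '{\n  "when": "2024-01-01 00:00:00"\n}'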
def safe_firebase_data(obj):
"""
Safely prepare data for Firebase by ensuring all values are JSON-serializable
and Firebase-compatible.
"""
def clean_for_firebase(obj):
if obj is None:
return None
elif isinstance(obj, bool):
return obj
elif isinstance(obj, (int, float)):
return obj
elif isinstance(obj, str):
# Clean string for Firebase - remove null bytes and control characters
return ''.join(char for char in obj if ord(char) >= 32 or char in '\n\r\t')
elif isinstance(obj, (list, tuple)):
return [clean_for_firebase(item) for item in obj]
elif isinstance(obj, dict):
cleaned = {}
for key, value in obj.items():
# Firebase keys must be strings and can't contain certain characters
clean_key = str(key).replace('.', '_').replace('$', '_').replace('#', '_').replace('[', '_').replace(']', '_').replace('/', '_')
cleaned[clean_key] = clean_for_firebase(value)
return cleaned
else:
# Convert to string and clean
str_repr = str(obj)
return ''.join(char for char in str_repr if ord(char) >= 32 or char in '\n\r\t')
return clean_for_firebase(obj)
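# Example: keys are sanitized and control characters stripped recursively.
#     safe_firebase_data({"price.usd": 1.5, "note": "ok\x00"})
#     -> {'price_usd': 1.5, 'note': 'ok'}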
# --- Main Business Logic Functions for Flask ---
def sanitize_for_firebase_key(text: str) -> str:
"""Replaces Firebase-forbidden characters in a string with underscores."""
forbidden_chars = ['.', '$', '#', '[', ']', '/']
for char in forbidden_chars:
text = text.replace(char, '_')
return text
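# Example: every character Firebase forbids in a key becomes an underscore.
#     sanitize_for_firebase_key('bar | Sales by $ region')
#     -> 'bar | Sales by _ region'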
from scipy import stats
def analyze_data_intelligence(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
"""
Autonomous data intelligence system that classifies domain,
detects patterns, and determines optimal analytical approach.
"""
# Domain Classification Engine
domain_signals = {
'financial': ['amount', 'price', 'cost', 'revenue', 'profit', 'balance', 'transaction', 'payment'],
'survey': ['rating', 'satisfaction', 'score', 'response', 'feedback', 'opinion', 'agree', 'likert'],
'scientific': ['measurement', 'experiment', 'trial', 'test', 'control', 'variable', 'hypothesis'],
'marketing': ['campaign', 'conversion', 'click', 'impression', 'engagement', 'customer', 'segment'],
'operational': ['performance', 'efficiency', 'throughput', 'capacity', 'utilization', 'process'],
'temporal': ['date', 'time', 'timestamp', 'period', 'month', 'year', 'day', 'hour']
}
# Analyze column patterns
columns_lower = [col.lower() for col in df.columns]
domain_scores = {}
for domain, keywords in domain_signals.items():
score = sum(1 for col in columns_lower if any(keyword in col for keyword in keywords))
domain_scores[domain] = score
# Determine primary domain
primary_domain = max(domain_scores, key=domain_scores.get) if max(domain_scores.values()) > 0 else 'general'
# Data Structure Analysis
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()
# Detect time series
is_timeseries = len(datetime_cols) > 0 or any('date' in col.lower() or 'time' in col.lower() for col in columns_lower)
# Statistical Profile
statistical_summary = {}
if numeric_cols:
        try:
            # Mask the diagonal before taking per-column maxima: self-correlation
            # is always 1.0 and would drown out every real relationship.
            corr_abs = df[numeric_cols].corr().abs()
            np.fill_diagonal(corr_abs.values, np.nan)
            correlations_dict = {k: float(v) if pd.notna(v) else 0.0 for k, v in corr_abs.max().to_dict().items()}
            distributions = {}
            for col in numeric_cols:
                if len(df[col].dropna()) > 8:
                    try:
                        p_value = stats.normaltest(df[col].dropna())[1]
                        distributions[col] = 'normal' if p_value > 0.05 else 'non_normal'
                    except Exception:
                        distributions[col] = 'unknown'
            outliers = {}
            for col in numeric_cols:
                clean = df[col].dropna()
                if len(clean) > 0:
                    try:
                        # Count |z| > 3 on the NaN-free values directly; indexing
                        # df[col] with a mask built from the dropna'd array
                        # misaligns whenever the column has missing values.
                        z_scores = np.abs(stats.zscore(clean))
                        outliers[col] = int((z_scores > 3).sum())
                    except Exception:
                        outliers[col] = 0
statistical_summary = {
'correlations': correlations_dict,
'distributions': distributions,
'outliers': outliers
}
except Exception as e:
statistical_summary = {'error': 'Could not compute statistical summary'}
# Pattern Detection
patterns = {
'has_missing_data': df.isnull().sum().sum() > 0,
'has_duplicates': df.duplicated().sum() > 0,
'has_negative_values': any(df[col].min() < 0 for col in numeric_cols if len(df[col].dropna()) > 0),
'has_categorical_hierarchy': any(len(df[col].unique()) > 10 for col in categorical_cols),
'potential_segments': len(categorical_cols) > 0
}
# Insight Opportunities
insight_opportunities = []
if is_timeseries:
insight_opportunities.append("temporal_trends")
if len(numeric_cols) > 1:
insight_opportunities.append("correlations")
if len(categorical_cols) > 0 and len(numeric_cols) > 0:
insight_opportunities.append("segmentation")
if any(statistical_summary.get('outliers', {}).values()):
insight_opportunities.append("anomalies")
return {
'primary_domain': primary_domain,
'domain_confidence': domain_scores,
'data_structure': {
'is_timeseries': is_timeseries,
'numeric_cols': numeric_cols,
'categorical_cols': categorical_cols,
'datetime_cols': datetime_cols
},
'statistical_profile': statistical_summary,
'patterns': patterns,
'insight_opportunities': insight_opportunities,
'narrative_suggestions': get_narrative_suggestions(primary_domain, insight_opportunities, patterns)
}
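# Sketch of the output: a frame whose columns include "revenue" and "cost"
# scores highest on the financial keyword list, so primary_domain comes back
# "financial", and with two or more numeric columns "correlations" appears in
# insight_opportunities.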
def get_narrative_suggestions(domain: str, opportunities: List[str], patterns: Dict) -> Dict[str, str]:
"""Generate narrative direction based on domain and data characteristics"""
narrative_frameworks = {
'financial': {
'hook': "Follow the money trail that reveals your business's hidden opportunities",
'structure': "performance → trends → risks → opportunities",
'focus': "profitability, efficiency, growth patterns, risk indicators"
},
'survey': {
'hook': "Your customers are speaking - here's what they're really saying",
'structure': "sentiment → segments → drivers → actions",
'focus': "satisfaction drivers, demographic patterns, improvement areas"
},
'scientific': {
'hook': "The data reveals relationships that challenge conventional thinking",
'structure': "hypothesis → evidence → significance → implications",
'focus': "statistical significance, correlations, experimental validity"
},
'marketing': {
'hook': "Discover the customer journey patterns driving your growth",
'structure': "performance → segments → optimization → strategy",
'focus': "conversion funnels, customer segments, campaign effectiveness"
},
'operational': {
'hook': "Operational excellence lives in the details - here's where to look",
'structure': "efficiency → bottlenecks → optimization → impact",
'focus': "process efficiency, capacity utilization, improvement opportunities"
},
'general': {
'hook': "Every dataset tells a story - here's what yours is saying",
'structure': "overview → patterns → insights → implications",
'focus': "key patterns, significant relationships, actionable insights"
}
}
return narrative_frameworks.get(domain, narrative_frameworks['general'])
def json_serializable(obj):
"""Convert objects to JSON-serializable format"""
if isinstance(obj, (np.integer, np.floating)):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.bool_, bool)):
return bool(obj)
elif isinstance(obj, dict):
return {k: json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [json_serializable(item) for item in obj]
elif pd.isna(obj):
return None
else:
return obj
def create_autonomous_prompt(df: pd.DataFrame, enhanced_ctx: Dict, intelligence: Dict) -> str:
"""
Generate a dynamic, intelligence-driven prompt that creates compelling narratives
rather than following templates.
"""
domain = intelligence['primary_domain']
opportunities = intelligence['insight_opportunities']
narrative = intelligence['narrative_suggestions']
# Dynamic chart strategy based on data characteristics
chart_strategy = generate_chart_strategy(intelligence)
# Make context JSON serializable
serializable_ctx = json_serializable(enhanced_ctx)
prompt = f"""You are an elite data storyteller with deep expertise in {domain} analytics. Your mission is to uncover the compelling narrative hidden in this dataset and present it as a captivating story that drives action.
**THE DATA'S STORY CONTEXT:**
{json.dumps(serializable_ctx, indent=2)}
**INTELLIGENCE ANALYSIS:**
- Primary Domain: {domain}
- Key Opportunities: {', '.join(opportunities)}
- Data Characteristics: {json_serializable(intelligence['data_structure'])}
- Narrative Framework: {narrative['structure']}
**YOUR STORYTELLING MISSION:**
{narrative['hook']}
**NARRATIVE CONSTRUCTION GUIDELINES:**
1. **LEAD WITH INTRIGUE**: Start with the most compelling finding that hooks the reader
2. **BUILD TENSION**: Present contrasts, surprises, or unexpected patterns
3. **REVEAL INSIGHTS**: Use data to resolve the tension with clear explanations
4. **DRIVE ACTION**: End with specific, actionable recommendations
**VISUALIZATION STRATEGY:**
{chart_strategy}
**CRITICAL INSTRUCTIONS:**
- Write as if you're revealing a detective story, not filling a template
- Every insight must be supported by data evidence
- Use compelling headers that create curiosity (not "Executive Summary")
- Weave charts naturally into the narrative flow
- Focus on business impact and actionable outcomes
- Let the data's personality shine through your writing style
**CHART INTEGRATION:**
Insert charts using: `<generate_chart: "chart_type | compelling description that advances the story">`
Available types: bar, pie, line, scatter, hist
Transform this data into a story that decision-makers can't stop reading."""
return prompt
def generate_chart_strategy(intelligence: Dict) -> str:
"""Generate visualization strategy based on data intelligence"""
domain = intelligence['primary_domain']
opportunities = intelligence['insight_opportunities']
structure = intelligence['data_structure']
strategies = {
'financial': "Focus on trend lines showing performance over time, comparative bars for different categories, and scatter plots revealing correlations between financial metrics.",
'survey': "Emphasize distribution histograms for satisfaction scores, segmented bar charts for demographic breakdowns, and correlation matrices for response patterns.",
'scientific': "Prioritize scatter plots with regression lines, distribution comparisons, and statistical significance visualizations.",
'marketing': "Highlight conversion funnels, customer segment comparisons, and campaign performance trends.",
'operational': "Show efficiency trends, capacity utilization charts, and process performance comparisons."
}
base_strategy = strategies.get(domain, "Create visualizations that best tell your data's unique story.")
# Add specific guidance based on data characteristics
if structure['is_timeseries']:
base_strategy += " Leverage time-series visualizations to show trends and patterns over time."
if 'correlations' in opportunities:
base_strategy += " Include correlation visualizations to reveal hidden relationships."
if 'segmentation' in opportunities:
base_strategy += " Use segmented charts to highlight different groups or categories."
return base_strategy
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
"""Enhanced context generation with AI-driven analysis"""
# Get autonomous intelligence analysis
intelligence = analyze_data_intelligence(df, ctx_dict)
# Original context enhancement
enhanced = ctx_dict.copy()
# Add statistical context
if not df.empty:
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
key_metrics = {}
for col in numeric_cols[:3]: # Top 3 numeric columns
try:
mean_val = df[col].mean()
std_val = df[col].std()
key_metrics[col] = {
'mean': float(mean_val) if pd.notna(mean_val) else 0.0,
'std': float(std_val) if pd.notna(std_val) else 0.0
}
                except Exception:
key_metrics[col] = {'mean': 0.0, 'std': 0.0}
enhanced['statistical_summary'] = {
'numeric_columns': int(len(numeric_cols)),
'total_records': int(len(df)),
'missing_data_percentage': float((df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100),
'key_metrics': key_metrics
}
# Add categorical context
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
if len(categorical_cols) > 0:
unique_values = {}
for col in categorical_cols[:3]:
try:
unique_values[col] = int(df[col].nunique())
            except Exception:
unique_values[col] = 0
enhanced['categorical_summary'] = {
'categorical_columns': int(len(categorical_cols)),
'unique_values': unique_values
}
    # Keep the flat column lists that ChartGenerator's fallback path expects.
    enhanced['numeric_columns'] = df.select_dtypes(include=[np.number]).columns.tolist()
    enhanced['categorical_columns'] = df.select_dtypes(exclude=[np.number]).columns.tolist()
    # Merge with intelligence analysis
    enhanced['ai_intelligence'] = intelligence
    return enhanced
def create_chart_safe_context(enhanced_ctx: Dict) -> Dict:
"""
Create a chart-generator-safe version of enhanced context
by ensuring all values are JSON serializable
"""
def make_json_safe(obj):
if isinstance(obj, bool):
return bool(obj)
elif isinstance(obj, (np.integer, np.floating)):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, dict):
return {k: make_json_safe(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [make_json_safe(item) for item in obj]
elif pd.isna(obj):
return None
elif hasattr(obj, 'item'): # numpy scalars
return obj.item()
else:
return obj
return make_json_safe(enhanced_ctx)
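# Example: numpy scalars come back as plain Python values.
#     create_chart_safe_context({"flag": np.bool_(True), "n": np.int64(3)})
#     -> {'flag': True, 'n': 3.0}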
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
"""
Enhanced autonomous report generation with intelligent narrative creation
"""
logging.info(f"Generating autonomous report draft for project {project_id}")
df = load_dataframe_safely(buf, name)
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
# Build enhanced context with AI intelligence
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
enhanced_ctx = enhance_data_context(df, ctx_dict)
# Get AI intelligence analysis
intelligence = analyze_data_intelligence(df, ctx_dict)
# Generate autonomous prompt
report_prompt = create_autonomous_prompt(df, enhanced_ctx, intelligence)
# Generate the report
md = llm.invoke(report_prompt).content
# Extract and process charts
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
chart_urls = {}
# Create a chart-safe context
chart_safe_ctx = create_chart_safe_context(enhanced_ctx)
# Try to pass the safe context to ChartGenerator
try:
chart_generator = ChartGenerator(llm, df, chart_safe_ctx)
except TypeError:
# Fallback: if ChartGenerator doesn't accept enhanced_ctx parameter
chart_generator = ChartGenerator(llm, df)
# If it has an enhanced_ctx attribute, set it safely
if hasattr(chart_generator, 'enhanced_ctx'):
chart_generator.enhanced_ctx = chart_safe_ctx
for desc in chart_descs:
# Create a safe key for Firebase
safe_desc = sanitize_for_firebase_key(desc)
# Replace the original description in the markdown with the safe one
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">') # Handle no quotes case
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(desc) # Still generate spec from original desc
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
# Use the safe key in the dictionary
chart_urls[safe_desc] = blob.public_url
logging.info(f"Uploaded chart '{desc}' to {blob.public_url} with safe key '{safe_desc}'")
finally:
if os.path.exists(img_path):
os.unlink(img_path)
return {"raw_md": md, "chartUrls": chart_urls}
# Additional helper functions for the autonomous system
def detect_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
"""Detect relationships and patterns in the data"""
numeric_cols = df.select_dtypes(include=[np.number]).columns
relationships = {}
if len(numeric_cols) > 1:
corr_matrix = df[numeric_cols].corr()
# Find strong correlations (> 0.7 or < -0.7)
strong_correlations = []
for i in range(len(corr_matrix.columns)):
for j in range(i+1, len(corr_matrix.columns)):
corr_val = corr_matrix.iloc[i, j]
if abs(corr_val) > 0.7:
strong_correlations.append({
'var1': corr_matrix.columns[i],
'var2': corr_matrix.columns[j],
'correlation': corr_val
})
relationships['strong_correlations'] = strong_correlations
return relationships
def identify_key_metrics(df: pd.DataFrame, domain: str) -> List[str]:
"""Identify the most important metrics based on domain and data characteristics"""
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
domain_priorities = {
'financial': ['revenue', 'profit', 'cost', 'amount', 'price', 'margin'],
'survey': ['rating', 'score', 'satisfaction', 'response'],
'marketing': ['conversion', 'click', 'impression', 'engagement'],
'operational': ['efficiency', 'utilization', 'throughput', 'performance']
}
priorities = domain_priorities.get(domain, [])
key_metrics = []
# Match column names with domain priorities
for col in numeric_cols:
col_lower = col.lower()
for priority in priorities:
if priority in col_lower:
key_metrics.append(col)
break
# If no matches, use columns with highest variance (most interesting)
if not key_metrics and numeric_cols:
variances = df[numeric_cols].var().sort_values(ascending=False)
key_metrics = variances.head(3).index.tolist()
return key_metrics[:5] # Return top 5 key metrics
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
"""
Generates charts autonomously based on the report content and data characteristics.
"""
# Extract chart descriptions from the enhanced report
chart_descs = extract_chart_tags(report_md)[:MAX_CHARTS]
chart_urls = {}
if not chart_descs:
# If no charts specified, generate intelligent defaults
chart_descs = generate_intelligent_chart_suggestions(df, llm)
chart_generator = ChartGenerator(llm, df)
for desc in chart_descs:
try:
# Create a safe key for Firebase
safe_desc = sanitize_for_firebase_key(desc)
# Replace chart tags in markdown
report_md = report_md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
report_md = report_md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
# Generate chart
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(desc)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
chart_urls[safe_desc] = blob.public_url
logging.info(f"Generated autonomous chart: {safe_desc}")
finally:
if os.path.exists(img_path):
os.unlink(img_path)
except Exception as e:
logging.error(f"Failed to generate chart '{desc}': {str(e)}")
continue
return chart_urls
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
"""
Generates intelligent chart suggestions based on data characteristics.
"""
numeric_cols = df.select_dtypes(include=[np.number]).columns
categorical_cols = df.select_dtypes(include=['object']).columns
suggestions = []
# Time series chart if temporal data exists
if detect_time_series(df):
suggestions.append("line | Time series trend analysis | Show temporal patterns")
# Distribution chart for numeric data
if len(numeric_cols) > 0:
main_numeric = numeric_cols[0]
suggestions.append(f"hist | Distribution of {main_numeric} | Understand data distribution")
# Correlation analysis if multiple numeric columns
if len(numeric_cols) > 1:
suggestions.append("scatter | Correlation analysis | Identify relationships between variables")
# Categorical breakdown
if len(categorical_cols) > 0:
main_categorical = categorical_cols[0]
suggestions.append(f"bar | {main_categorical} breakdown | Show categorical distribution")
return suggestions[:MAX_CHARTS]
# Helper functions
def detect_time_series(df: pd.DataFrame) -> bool:
    """Detect if dataset contains time series data."""
    for col in df.columns:
        if 'date' in col.lower() or 'time' in col.lower():
            return True
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            return True
        # Only try parsing string-like columns: pd.to_datetime happily treats
        # plain numeric columns as epoch timestamps, which made this return
        # True for almost any dataset.
        if df[col].dtype == object:
            try:
                pd.to_datetime(df[col])
                return True
            except (ValueError, TypeError):
                continue
    return False
def detect_transactional_data(df: pd.DataFrame) -> bool:
"""Detect if dataset contains transactional data."""
transaction_indicators = ['transaction', 'payment', 'order', 'invoice', 'amount', 'quantity']
columns_lower = [col.lower() for col in df.columns]
return any(indicator in col for col in columns_lower for indicator in transaction_indicators)
def detect_experimental_data(df: pd.DataFrame) -> bool:
"""Detect if dataset contains experimental data."""
experimental_indicators = ['test', 'experiment', 'trial', 'group', 'treatment', 'control']
columns_lower = [col.lower() for col in df.columns]
return any(indicator in col for col in columns_lower for indicator in experimental_indicators)
def detect_temporal_frequency(date_series: pd.Series) -> str:
"""Detect the frequency of temporal data."""
if len(date_series) < 2:
return "insufficient_data"
# Calculate time differences
time_diffs = date_series.sort_values().diff().dropna()
median_diff = time_diffs.median()
if median_diff <= pd.Timedelta(days=1):
return "daily"
elif median_diff <= pd.Timedelta(days=7):
return "weekly"
elif median_diff <= pd.Timedelta(days=31):
return "monthly"
else:
return "irregular"
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
"""Determine the complexity level of analysis required."""
complexity_factors = 0
# Data size factor
if len(df) > 10000:
complexity_factors += 1
if len(df.columns) > 20:
complexity_factors += 1
# Data type diversity
if len(df.select_dtypes(include=[np.number]).columns) > 5:
complexity_factors += 1
if len(df.select_dtypes(include=['object']).columns) > 5:
complexity_factors += 1
# Domain complexity
if domain_analysis["primary_domain"] in ["scientific", "financial"]:
complexity_factors += 1
if complexity_factors >= 3:
return "high"
elif complexity_factors >= 2:
return "medium"
else:
return "low"
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
"""
Fallback to original report generation logic if enhanced version fails.
"""
logging.info("Using fallback report generation")
# Original logic preserved
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
enhanced_ctx = enhance_data_context(df, ctx_dict)
report_prompt = f"""
You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
**Instructions:**
1. **Executive Summary**: Start with a high-level summary of key findings.
2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
Valid chart types: bar, pie, line, scatter, hist.
Generate insights that would be valuable to C-level executives.
"""
md = llm.invoke(report_prompt).content
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
chart_urls = {}
chart_generator = ChartGenerator(llm, df)
for desc in chart_descs:
safe_desc = sanitize_for_firebase_key(desc)
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(desc)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
chart_urls[safe_desc] = blob.public_url
finally:
if os.path.exists(img_path):
os.unlink(img_path)
return {"raw_md": md, "chartUrls": chart_urls}
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
"""
Generates a basic fallback report when enhanced generation fails.
"""
basic_info = autonomous_context["basic_info"]
domain = autonomous_context["domain"]["primary_domain"]
return f"""
# What This Data Reveals
Looking at this {domain} dataset with {basic_info['shape'][0]} records, there are several key insights worth highlighting.
## The Numbers Tell a Story
This dataset contains {basic_info['shape'][1]} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.
<generate_chart: "bar | Data overview showing key metrics">
## What You Should Know
The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicate multiple analytical opportunities.
## Next Steps
Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.
*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
"""
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
logging.info(f"Generating single chart '{description}' for project {project_id}")
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
chart_generator = ChartGenerator(llm, df)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(description)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
logging.info(f"Uploaded single chart to {blob.public_url}")
return blob.public_url
finally:
if os.path.exists(img_path):
os.unlink(img_path)
return None
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.2)
story_prompt = f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. Report: {raw_md}"
script = llm.invoke(story_prompt).content
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
video_parts, audio_parts, temps = [], [], []
for sc in scenes:
descs, narrative = extract_chart_tags(sc), clean_narration(sc)
audio_bytes = deepgram_tts(narrative, voice_model)
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
if audio_bytes:
mp3.write_bytes(audio_bytes); dur = audio_duration(str(mp3))
if dur <= 0.1: dur = 5.0
else:
dur = 5.0; generate_silence_mp3(dur, mp3)
audio_parts.append(str(mp3)); temps.append(mp3)
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
if descs: safe_chart(descs[0], df, dur, mp4)
else:
img = generate_image_from_prompt(narrative)
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
animate_image_fade(img_cv, dur, mp4)
video_parts.append(str(mp4)); temps.append(mp4)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
silent_vid_path = Path(temp_vid.name)
audio_mix_path = Path(temp_aud.name)
final_vid_path = Path(final_vid.name)
concat_media(video_parts, silent_vid_path)
concat_media(audio_parts, audio_mix_path)
subprocess.run(
["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
"-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
check=True, capture_output=True,
)
blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(final_vid_path))
logging.info(f"Uploaded video to {blob.public_url}")
for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
if os.path.exists(p): os.unlink(p)
return blob.public_url