# sozo_gen.py
import os
import re
import json
import logging
import uuid
import io
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter
from PIL import Image
import cv2
import inspect
import tempfile
import subprocess
import shutil
from typing import Any, Dict, List
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai  # GenerativeModel lives in google.generativeai, not google.genai
import requests
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
FPS, WIDTH, HEIGHT = 24, 1280, 720
MAX_CHARTS, VIDEO_SCENES = 5, 5
# --- Gemini API Initialization ---
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable not set.")
genai.configure(api_key=API_KEY)
# --- Helper Functions ---
def load_dataframe_safely(buf, name: str):
ext = Path(name).suffix.lower()
df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
df.columns = df.columns.astype(str).str.strip()
df = df.dropna(how="all")
if df.empty or len(df.columns) == 0: raise ValueError("No usable data found")
return df
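# Usage sketch with a hypothetical in-memory CSV:
#   load_dataframe_safely(io.BytesIO(b"region,sales\neast,100\nwest,80\n"), "sales.csv")
#   -> 2x2 DataFrame with stripped column names ("region", "sales")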
def deepgram_tts(txt: str, voice_model: str):
DG_KEY = os.getenv("DEEPGRAM_API_KEY")
if not DG_KEY or not txt: return None
txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
try:
r = requests.post("https://api.deepgram.com/v1/speak", params={"model": voice_model}, headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"}, json={"text": txt}, timeout=30)
r.raise_for_status()
return r.content
except Exception as e:
logging.error(f"Deepgram TTS failed: {e}")
return None
def generate_silence_mp3(duration: float, out: Path):
subprocess.run([ "ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono", "-t", f"{duration:.3f}", "-q:a", "9", str(out)], check=True, capture_output=True)
def audio_duration(path: str) -> float:
try:
res = subprocess.run([ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=nw=1:nk=1", path], text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
return float(res.stdout.strip())
except Exception: return 5.0
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
def extract_chart_tags(t: str) -> List[str]:
    return list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
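# The tag grammar is deliberately loose; both bracket styles and optional quotes parse, e.g.:
#   extract_chart_tags('<generate_chart: "bar | Revenue by region">')  -> ['bar | Revenue by region']
#   extract_chart_tags('[generate_chart: line | Monthly signups]')     -> ['line | Monthly signups']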
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
def clean_narration(txt: str) -> str:
txt = TAG_RE.sub("", txt); txt = re_scene.sub("", txt)
phrases_to_remove = [r"as you can see in the chart", r"this chart shows", r"the chart illustrates", r"in this visual", r"this graph displays"]
for phrase in phrases_to_remove: txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
txt = re.sub(r"\s*\([^)]*\)", "", txt); txt = re.sub(r"[\*#_]", "", txt)
return re.sub(r"\s{2,}", " ", txt).strip()
def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
def generate_image_from_prompt(prompt: str) -> Image.Image:
    # A text-first model may not return inline image data; the placeholder
    # fallback below covers that case.
    model_main = "gemini-1.5-flash-latest"
    full_prompt = "A clean business-presentation illustration: " + prompt
try:
model = genai.GenerativeModel(model_main)
res = model.generate_content(full_prompt)
img_part = next((part for part in res.candidates[0].content.parts if getattr(part, "inline_data", None)), None)
if img_part:
return Image.open(io.BytesIO(img_part.inline_data.data)).convert("RGB")
return placeholder_img()
except Exception:
return placeholder_img()
# --- Chart Generation System ---
class ChartSpecification:
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
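# Example spec (hypothetical columns):
#   ChartSpecification("bar", "Top Products by Revenue", x_col="product", y_col="revenue", agg_method="sum", top_n=10)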
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
enhanced_ctx = ctx_dict.copy(); numeric_cols = df.select_dtypes(include=['number']).columns.tolist(); categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
return enhanced_ctx
class ChartGenerator:
def __init__(self, llm, df: pd.DataFrame):
self.llm = llm; self.df = df
self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
def generate_chart_spec(self, description: str) -> ChartSpecification:
spec_prompt = f"""
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
**Chart Request:** {description}
**Return a JSON specification with these exact fields:**
{{
"chart_type": "bar|pie|line|scatter|hist", "title": "Professional chart title", "x_col": "column_name_for_x_axis",
"y_col": "column_name_for_y_axis_or_null", "agg_method": "sum|mean|count|max|min|null", "top_n": "number_for_top_n_filtering_or_null"
}}
Return only the JSON specification, no additional text.
"""
try:
response = self.llm.invoke(spec_prompt).content.strip()
            response = re.sub(r"^```(?:json)?\s*|\s*```$", "", response).strip()
spec_dict = json.loads(response)
valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values() if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
return ChartSpecification(**filtered_dict)
except Exception as e:
logging.error(f"Spec generation failed: {e}. Using fallback.")
return self._create_fallback_spec(description)
def _create_fallback_spec(self, description: str) -> ChartSpecification:
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
ctype = "bar"
for t in ["pie", "line", "scatter", "hist"]:
if t in description.lower(): ctype = t
x = categorical_cols[0] if categorical_cols else self.df.columns[0]
y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
return ChartSpecification(ctype, description, x, y)
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
try:
plot_data = prepare_plot_data(spec, df)
fig, ax = plt.subplots(figsize=(12, 8)); plt.style.use('default')
if spec.chart_type == "bar": ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.tick_params(axis='x', rotation=45)
elif spec.chart_type == "pie": ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90); ax.axis('equal')
elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
return True
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
if spec.chart_type in ["bar", "pie"]:
if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
return grouped.nlargest(spec.top_n or 10)
elif spec.chart_type == "line": return df.set_index(spec.x_col)[spec.y_col].sort_index()
elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
elif spec.chart_type == "hist": return df[spec.x_col].dropna()
return df[spec.x_col]
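# For a hypothetical df with columns ["region", "sales"], a bar spec groups and ranks:
#   prepare_plot_data(ChartSpecification("bar", "Sales by Region", "region", "sales"), df)
#   -> Series indexed by region with summed sales, top 10 by value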
# --- Animation & Video Generation ---
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
plot_data = prepare_plot_data(spec, df)
frames = max(10, int(dur * fps))
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
plt.tight_layout(pad=3.0)
ctype = spec.chart_type
if ctype == "pie":
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
ax.set_title(spec.title); ax.axis('equal')
def init(): [w.set_alpha(0) for w in wedges]; return wedges
def update(i): [w.set_alpha(i / (frames - 1)) for w in wedges]; return wedges
elif ctype == "bar":
bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
ax.set_title(spec.title); plt.xticks(rotation=45, ha="right")
def init(): return bars
def update(i):
for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
return bars
elif ctype == "scatter":
scat = ax.scatter([], [], alpha=0.7)
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
def init(): scat.set_offsets(np.empty((0, 2))); return [scat]
def update(i):
k = max(1, int(len(x_full) * (i / (frames - 1))))
scat.set_offsets(plot_data.iloc[:k].values); return [scat]
elif ctype == "hist":
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
def init(): [p.set_alpha(0) for p in patches]; return patches
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
else: # line
line, = ax.plot([], [], lw=2)
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
x_full, y_full = plot_data.index, plot_data.values
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
def init(): line.set_data([], []); return [line]
def update(i):
k = max(2, int(len(x_full) * (i / (frames - 1))))
line.set_data(x_full[:k], y_full[:k]); return [line]
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
plt.close(fig)
return str(out)
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
fourcc = cv2.VideoWriter_fourcc(*'mp4v'); video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
total_frames = max(1, int(dur * fps))
for i in range(total_frames):
alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
video_writer.write(frame)
video_writer.release()
return str(out)
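# Callers must pass a BGR frame already sized to (WIDTH, HEIGHT); frames whose
# dimensions differ from the writer's may be silently dropped by cv2.VideoWriter.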
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
try:
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
chart_generator = ChartGenerator(llm, df)
chart_spec = chart_generator.generate_chart_spec(desc)
return animate_chart(chart_spec, df, dur, out)
except Exception as e:
logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
temp_png = Path(temp_png_file.name)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
chart_generator = ChartGenerator(llm, df)
chart_spec = chart_generator.generate_chart_spec(desc)
if execute_chart_spec(chart_spec, df, temp_png):
img = cv2.imread(str(temp_png)); os.unlink(temp_png)
img_resized = cv2.resize(img, (WIDTH, HEIGHT))
return animate_image_fade(img_resized, dur, out)
else:
img = generate_image_from_prompt(f"A professional business chart showing {desc}")
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
return animate_image_fade(img_cv, dur, out)
def concat_media(file_paths: List[str], output_path: Path):
valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
if not valid_paths: raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1: shutil.copy2(valid_paths[0], str(output_path)); return
list_file = output_path.with_suffix(".txt")
with open(list_file, 'w') as f:
for path in valid_paths: f.write(f"file '{Path(path).resolve()}'\n")
cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
finally:
list_file.unlink(missing_ok=True)
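# The temporary list file uses ffmpeg's concat-demuxer syntax, one entry per line:
#   file '/tmp/scene1.mp4'
#   file '/tmp/scene2.mp4'
# Note: "-c copy" assumes every clip shares codec/resolution/fps; clips from
# animate_chart (Matplotlib/FFMpegWriter) and animate_image_fade (OpenCV mp4v)
# may differ, in which case re-encoding instead of stream copy would be needed.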
# --- Main Business Logic Functions for Flask ---
def sanitize_for_firebase_key(text: str) -> str:
"""Replaces Firebase-forbidden characters in a string with underscores."""
forbidden_chars = ['.', '$', '#', '[', ']', '/']
for char in forbidden_chars:
text = text.replace(char, '_')
return text
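# e.g. sanitize_for_firebase_key('bar | Revenue ($) [2024]') -> 'bar | Revenue (_) _2024_'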
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
"""
Enhanced autonomous data analysis function that intelligently analyzes any dataset
and generates comprehensive, domain-appropriate reports with contextual visualizations.
Maintains backward compatibility with existing function signature and outputs.
"""
logging.info(f"Generating enhanced autonomous report draft for project {project_id}")
# Load data safely (existing functionality preserved)
df = load_dataframe_safely(buf, name)
# Initialize LLM (existing setup preserved)
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
# Enhanced autonomous data analysis
try:
# Stage 1: Intelligent Data Classification and Deep Analysis
autonomous_context = perform_autonomous_data_analysis(df, ctx, name)
# Stage 2: Generate Enhanced Report with Intelligent Narrative
enhanced_report = generate_intelligent_report(llm, autonomous_context)
# Stage 3: Smart Chart Generation
        enhanced_report, chart_urls = generate_autonomous_charts(llm, df, enhanced_report, uid, project_id, bucket)
        # Preserve original output structure
        return {"raw_md": enhanced_report, "chartUrls": chart_urls}
except Exception as e:
logging.error(f"Enhanced analysis failed, falling back to original: {str(e)}")
# Fallback to original logic if enhancement fails
return generate_original_report(df, llm, ctx, uid, project_id, bucket)
def perform_autonomous_data_analysis(df: pd.DataFrame, user_ctx: str, filename: str) -> Dict[str, Any]:
"""
Performs comprehensive autonomous analysis of the dataset to understand its nature,
domain, and analytical potential.
"""
logging.info("Performing autonomous data analysis...")
# Basic data profiling
basic_info = {
"shape": df.shape,
"columns": list(df.columns),
"dtypes": df.dtypes.to_dict(),
"filename": filename,
"user_context": user_ctx
}
# Intelligent domain classification
domain_analysis = classify_dataset_domain(df, filename)
# Advanced statistical analysis
statistical_profile = generate_statistical_profile(df)
# Relationship discovery
relationships = discover_data_relationships(df)
# Temporal analysis if applicable
temporal_insights = analyze_temporal_patterns(df)
# Data quality assessment
quality_metrics = assess_data_quality(df)
# Business context inference
business_context = infer_business_context(df, domain_analysis)
return {
"basic_info": basic_info,
"domain": domain_analysis,
"statistics": statistical_profile,
"relationships": relationships,
"temporal": temporal_insights,
"quality": quality_metrics,
"business_context": business_context,
"analysis_complexity": determine_analysis_complexity(df, domain_analysis)
}
def classify_dataset_domain(df: pd.DataFrame, filename: str) -> Dict[str, Any]:
"""
Intelligently classifies the dataset domain based on column patterns, data types,
and semantic analysis.
"""
domain_indicators = {
"financial": ["amount", "price", "cost", "revenue", "profit", "transaction", "payment", "invoice"],
"survey": ["rating", "satisfaction", "response", "score", "survey", "feedback", "opinion"],
"scientific": ["measurement", "experiment", "test", "sample", "observation", "hypothesis", "variable"],
"marketing": ["campaign", "click", "conversion", "customer", "lead", "acquisition", "retention"],
"operational": ["process", "time", "duration", "status", "workflow", "performance", "efficiency"],
"sales": ["order", "product", "quantity", "sales", "customer", "deal", "pipeline"],
"hr": ["employee", "salary", "department", "performance", "training", "recruitment"],
"healthcare": ["patient", "diagnosis", "treatment", "medical", "health", "symptom", "medication"]
}
# Analyze column names for domain indicators
columns_lower = [col.lower() for col in df.columns]
domain_scores = {}
for domain, keywords in domain_indicators.items():
score = sum(1 for col in columns_lower for keyword in keywords if keyword in col)
domain_scores[domain] = score
# Filename analysis
filename_lower = filename.lower()
for domain, keywords in domain_indicators.items():
if any(keyword in filename_lower for keyword in keywords):
domain_scores[domain] = domain_scores.get(domain, 0) + 2
# Data type analysis
numeric_ratio = len(df.select_dtypes(include=[np.number]).columns) / len(df.columns)
categorical_ratio = len(df.select_dtypes(include=['object']).columns) / len(df.columns)
# Determine primary domain
    primary_domain = max(domain_scores, key=domain_scores.get) if any(domain_scores.values()) else "general"
return {
"primary_domain": primary_domain,
"domain_confidence": domain_scores.get(primary_domain, 0),
"domain_scores": domain_scores,
"data_characteristics": {
"numeric_ratio": numeric_ratio,
"categorical_ratio": categorical_ratio,
"is_time_series": detect_time_series(df),
"is_transactional": detect_transactional_data(df),
"is_experimental": detect_experimental_data(df)
}
}
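# e.g. columns ["order_id", "product", "quantity", "sales"] score 4 hits for the
# "sales" keyword list, so primary_domain == "sales" with domain_confidence == 4
# (before any filename bonus).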
def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
"""
Generates comprehensive statistical profile of the dataset.
"""
profile = {
"summary_stats": {},
"correlations": {},
"distributions": {},
"outliers": {},
"missing_data": {}
}
# Summary statistics for numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
profile["summary_stats"] = df[numeric_cols].describe().to_dict()
# Correlation analysis
if len(numeric_cols) > 1:
corr_matrix = df[numeric_cols].corr()
# Find strong correlations
strong_corrs = []
for i in range(len(corr_matrix.columns)):
for j in range(i+1, len(corr_matrix.columns)):
corr_val = corr_matrix.iloc[i, j]
if abs(corr_val) > 0.7: # Strong correlation threshold
strong_corrs.append({
"var1": corr_matrix.columns[i],
"var2": corr_matrix.columns[j],
"correlation": corr_val
})
profile["correlations"] = {"strong_correlations": strong_corrs}
# Categorical analysis
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0:
profile["categorical_analysis"] = {}
for col in categorical_cols:
profile["categorical_analysis"][col] = {
"unique_count": df[col].nunique(),
"top_values": df[col].value_counts().head(5).to_dict()
}
# Missing data analysis
missing_data = df.isnull().sum()
profile["missing_data"] = {
"columns_with_missing": missing_data[missing_data > 0].to_dict(),
"total_missing_percentage": (df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
}
return profile
def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
"""
Discovers meaningful relationships and patterns in the data.
"""
relationships = {
"key_relationships": [],
"patterns": [],
"anomalies": []
}
# Identify potential key relationships
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 1:
# Find interesting relationships
for col1 in numeric_cols:
for col2 in numeric_cols:
if col1 != col2:
correlation = df[col1].corr(df[col2])
if abs(correlation) > 0.5: # Moderate to strong correlation
relationships["key_relationships"].append({
"variable1": col1,
"variable2": col2,
"relationship_strength": correlation,
"relationship_type": "positive" if correlation > 0 else "negative"
})
# Identify patterns in categorical data
categorical_cols = df.select_dtypes(include=['object']).columns
for col in categorical_cols:
if df[col].nunique() < 20: # Reasonable number of categories
value_counts = df[col].value_counts()
if len(value_counts) > 0:
relationships["patterns"].append({
"column": col,
"pattern_type": "categorical_distribution",
"dominant_category": value_counts.index[0],
"dominance_percentage": (value_counts.iloc[0] / len(df)) * 100
})
return relationships
def analyze_temporal_patterns(df: pd.DataFrame) -> Dict[str, Any]:
"""
Analyzes temporal patterns if time-based columns are detected.
"""
temporal_insights = {"has_temporal_data": False}
# Detect date/time columns
date_columns = []
    for col in df.columns:
        if df[col].dtype == 'datetime64[ns]' or 'date' in col.lower() or 'time' in col.lower():
            try:
                pd.to_datetime(df[col])
                date_columns.append(col)
            except (ValueError, TypeError):
                continue
if date_columns:
temporal_insights["has_temporal_data"] = True
temporal_insights["date_columns"] = date_columns
# Analyze temporal patterns for the first date column
primary_date_col = date_columns[0]
df_temp = df.copy()
df_temp[primary_date_col] = pd.to_datetime(df_temp[primary_date_col])
temporal_insights["temporal_analysis"] = {
"date_range": {
"start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
"end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
},
"time_span_days": (df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days,
"frequency": detect_temporal_frequency(df_temp[primary_date_col])
}
return temporal_insights
def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
"""
Assesses data quality and identifies potential issues.
"""
quality_metrics = {
"overall_quality_score": 0,
"quality_issues": [],
"data_completeness": 0,
"data_consistency": {}
}
# Completeness assessment
completeness = (1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
quality_metrics["data_completeness"] = completeness
# Identify quality issues
if completeness < 95:
quality_metrics["quality_issues"].append("Missing data detected")
# Check for duplicates
duplicate_rows = df.duplicated().sum()
if duplicate_rows > 0:
quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
# Check for inconsistent data types
    for col in df.columns:
        if df[col].dtype == 'object':
            as_str = df[col].dropna().astype(str)
            if as_str.str.isnumeric().any() and not as_str.str.isnumeric().all():
                quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
# Calculate overall quality score
base_score = 100
base_score -= (100 - completeness) * 0.5 # Penalize missing data
base_score -= len(quality_metrics["quality_issues"]) * 5 # Penalize each quality issue
quality_metrics["overall_quality_score"] = max(0, base_score)
return quality_metrics
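# Worked example: 90% completeness flags "Missing data detected", so the score is
# 100 - (100 - 90) * 0.5 - 1 * 5 = 90 (assuming no duplicates or mixed-type columns).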
def infer_business_context(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""
Infers business context and potential use cases based on the data characteristics.
"""
domain = domain_analysis["primary_domain"]
context_mapping = {
"financial": {
"key_metrics": ["Revenue", "Profit", "Cost", "ROI"],
"typical_analyses": ["Trend analysis", "Profitability analysis", "Budget vs actual"],
"stakeholders": ["CFO", "Finance team", "Executive leadership"]
},
"survey": {
"key_metrics": ["Satisfaction scores", "Response rates", "Sentiment"],
"typical_analyses": ["Satisfaction analysis", "Demographic breakdown", "Correlation analysis"],
"stakeholders": ["Marketing team", "Product managers", "Customer success"]
},
"scientific": {
"key_metrics": ["Statistical significance", "Effect size", "Confidence intervals"],
"typical_analyses": ["Hypothesis testing", "Regression analysis", "Experimental validation"],
"stakeholders": ["Researchers", "Scientists", "Academic community"]
},
"marketing": {
"key_metrics": ["Conversion rates", "Customer acquisition cost", "Campaign ROI"],
"typical_analyses": ["Campaign performance", "Customer segmentation", "Attribution analysis"],
"stakeholders": ["Marketing team", "CMO", "Sales team"]
}
}
return context_mapping.get(domain, {
"key_metrics": ["Performance indicators", "Trends", "Patterns"],
"typical_analyses": ["Descriptive analysis", "Trend identification", "Pattern recognition"],
"stakeholders": ["Business stakeholders", "Decision makers"]
})
def generate_intelligent_report(llm, autonomous_context: Dict[str, Any]) -> str:
"""
Generates an intelligent, domain-appropriate report with organic storytelling.
"""
# Create truly autonomous prompt that lets AI decide everything
enhanced_prompt = f"""
You are a world-class data analyst who has just been handed this dataset to analyze. Look at the data characteristics and tell me the most compelling story you can find.
**DATASET CONTEXT:**
    {json.dumps(autonomous_context, indent=2, default=str)}
**YOUR MISSION:**
Analyze this data like you would if a CEO walked into your office and said "I need to understand what this data is telling us." Write a report that would make them say "This is exactly what I needed to know."
**GUIDELINES:**
- Don't follow a rigid structure - let the data guide your narrative
- Choose your own headings and sections based on what the data reveals
- Write like you're presenting findings to someone who needs to make important decisions
- Include specific numbers and insights that matter
- Insert chart recommendations like: `<generate_chart: "chart_type | description">`
- Valid chart types: bar, pie, line, scatter, hist, box, heatmap
- Only recommend charts that truly support your narrative
**FORGET TEMPLATES - TELL THE STORY:**
What's the most interesting, important, or surprising thing this data reveals? Start there and build your entire report around that central insight. Make it compelling, make it actionable, make it memorable.
Be the data analyst who gets promoted because they don't just present data - they reveal insights that drive business decisions.
"""
    try:
        return llm.invoke(enhanced_prompt).content
    except Exception as e:
        logging.error(f"Intelligent report generation failed: {e}; using fallback report.")
        return generate_fallback_report(autonomous_context)
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket):
    """
    Generates charts autonomously based on the report content and data characteristics.
    Returns (report_md, chart_urls): the markdown with sanitized chart tags plus the chart URL map.
    """
# Extract chart descriptions from the enhanced report
chart_descs = extract_chart_tags(report_md)[:MAX_CHARTS]
chart_urls = {}
if not chart_descs:
# If no charts specified, generate intelligent defaults
chart_descs = generate_intelligent_chart_suggestions(df, llm)
chart_generator = ChartGenerator(llm, df)
for desc in chart_descs:
try:
# Create a safe key for Firebase
safe_desc = sanitize_for_firebase_key(desc)
# Replace chart tags in markdown
report_md = report_md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
report_md = report_md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
# Generate chart
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(desc)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
chart_urls[safe_desc] = blob.public_url
logging.info(f"Generated autonomous chart: {safe_desc}")
finally:
if os.path.exists(img_path):
os.unlink(img_path)
except Exception as e:
logging.error(f"Failed to generate chart '{desc}': {str(e)}")
continue
    return report_md, chart_urls
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
"""
Generates intelligent chart suggestions based on data characteristics.
"""
numeric_cols = df.select_dtypes(include=[np.number]).columns
categorical_cols = df.select_dtypes(include=['object']).columns
suggestions = []
# Time series chart if temporal data exists
if detect_time_series(df):
suggestions.append("line | Time series trend analysis | Show temporal patterns")
# Distribution chart for numeric data
if len(numeric_cols) > 0:
main_numeric = numeric_cols[0]
suggestions.append(f"hist | Distribution of {main_numeric} | Understand data distribution")
# Correlation analysis if multiple numeric columns
if len(numeric_cols) > 1:
suggestions.append("scatter | Correlation analysis | Identify relationships between variables")
# Categorical breakdown
if len(categorical_cols) > 0:
main_categorical = categorical_cols[0]
suggestions.append(f"bar | {main_categorical} breakdown | Show categorical distribution")
return suggestions[:MAX_CHARTS]
# Helper functions (preserve existing functionality)
def detect_time_series(df: pd.DataFrame) -> bool:
    """Detect if dataset contains time series data."""
    for col in df.columns:
        if 'date' in col.lower() or 'time' in col.lower():
            return True
        if df[col].dtype == 'object':  # numeric columns would falsely parse as epoch timestamps
            try:
                pd.to_datetime(df[col])
                return True
            except (ValueError, TypeError):
                continue
    return False
def detect_transactional_data(df: pd.DataFrame) -> bool:
"""Detect if dataset contains transactional data."""
transaction_indicators = ['transaction', 'payment', 'order', 'invoice', 'amount', 'quantity']
columns_lower = [col.lower() for col in df.columns]
return any(indicator in col for col in columns_lower for indicator in transaction_indicators)
def detect_experimental_data(df: pd.DataFrame) -> bool:
"""Detect if dataset contains experimental data."""
experimental_indicators = ['test', 'experiment', 'trial', 'group', 'treatment', 'control']
columns_lower = [col.lower() for col in df.columns]
return any(indicator in col for col in columns_lower for indicator in experimental_indicators)
def detect_temporal_frequency(date_series: pd.Series) -> str:
"""Detect the frequency of temporal data."""
if len(date_series) < 2:
return "insufficient_data"
# Calculate time differences
time_diffs = date_series.sort_values().diff().dropna()
median_diff = time_diffs.median()
if median_diff <= pd.Timedelta(days=1):
return "daily"
elif median_diff <= pd.Timedelta(days=7):
return "weekly"
elif median_diff <= pd.Timedelta(days=31):
return "monthly"
else:
return "irregular"
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
"""Determine the complexity level of analysis required."""
complexity_factors = 0
# Data size factor
if len(df) > 10000:
complexity_factors += 1
if len(df.columns) > 20:
complexity_factors += 1
# Data type diversity
if len(df.select_dtypes(include=[np.number]).columns) > 5:
complexity_factors += 1
if len(df.select_dtypes(include=['object']).columns) > 5:
complexity_factors += 1
# Domain complexity
if domain_analysis["primary_domain"] in ["scientific", "financial"]:
complexity_factors += 1
if complexity_factors >= 3:
return "high"
elif complexity_factors >= 2:
return "medium"
else:
return "low"
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
"""
Fallback to original report generation logic if enhanced version fails.
"""
logging.info("Using fallback report generation")
# Original logic preserved
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
enhanced_ctx = enhance_data_context(df, ctx_dict)
report_prompt = f"""
You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
**Instructions:**
1. **Executive Summary**: Start with a high-level summary of key findings.
2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
Valid chart types: bar, pie, line, scatter, hist.
Generate insights that would be valuable to C-level executives.
"""
md = llm.invoke(report_prompt).content
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
chart_urls = {}
chart_generator = ChartGenerator(llm, df)
for desc in chart_descs:
safe_desc = sanitize_for_firebase_key(desc)
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(desc)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
chart_urls[safe_desc] = blob.public_url
finally:
if os.path.exists(img_path):
os.unlink(img_path)
return {"raw_md": md, "chartUrls": chart_urls}
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
"""
Generates a basic fallback report when enhanced generation fails.
"""
basic_info = autonomous_context["basic_info"]
domain = autonomous_context["domain"]["primary_domain"]
return f"""
# What This Data Reveals
Looking at this {domain} dataset with {basic_info['shape'][0]} records, there are several key insights worth highlighting.
## The Numbers Tell a Story
This dataset contains {basic_info['shape'][1]} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.
<generate_chart: "bar | Data overview showing key metrics">
## What You Should Know
The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicate multiple analytical opportunities.
## Next Steps
Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.
*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
"""
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
logging.info(f"Generating single chart '{description}' for project {project_id}")
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
chart_generator = ChartGenerator(llm, df)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
img_path = Path(temp_file.name)
try:
chart_spec = chart_generator.generate_chart_spec(description)
if execute_chart_spec(chart_spec, df, img_path):
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(img_path))
logging.info(f"Uploaded single chart to {blob.public_url}")
return blob.public_url
finally:
if os.path.exists(img_path):
os.unlink(img_path)
return None
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.2)
story_prompt = f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. Report: {raw_md}"
script = llm.invoke(story_prompt).content
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
video_parts, audio_parts, temps = [], [], []
for sc in scenes:
descs, narrative = extract_chart_tags(sc), clean_narration(sc)
audio_bytes = deepgram_tts(narrative, voice_model)
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
if audio_bytes:
mp3.write_bytes(audio_bytes); dur = audio_duration(str(mp3))
if dur <= 0.1: dur = 5.0
else:
dur = 5.0; generate_silence_mp3(dur, mp3)
audio_parts.append(str(mp3)); temps.append(mp3)
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
if descs: safe_chart(descs[0], df, dur, mp4)
else:
img = generate_image_from_prompt(narrative)
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
animate_image_fade(img_cv, dur, mp4)
video_parts.append(str(mp4)); temps.append(mp4)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
silent_vid_path = Path(temp_vid.name)
audio_mix_path = Path(temp_aud.name)
final_vid_path = Path(final_vid.name)
concat_media(video_parts, silent_vid_path)
concat_media(audio_parts, audio_mix_path)
subprocess.run(
["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
"-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
check=True, capture_output=True,
)
blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
blob = bucket.blob(blob_name)
blob.upload_from_filename(str(final_vid_path))
logging.info(f"Uploaded video to {blob.public_url}")
for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
if os.path.exists(p): os.unlink(p)
return blob.public_url
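if __name__ == "__main__":
    # Minimal local smoke test, not exercised by the Flask layer. It assumes
    # GOOGLE_API_KEY is set and a "sample.csv" sits next to this file; the
    # _DummyBucket below is a stand-in for the google-cloud-storage bucket
    # the web app normally injects, so nothing is actually uploaded.
    class _DummyBlob:
        def __init__(self, name):
            self.name = name
            self.public_url = f"local://{name}"
        def upload_from_filename(self, path):
            logging.info(f"Would upload {path} -> {self.name}")
    class _DummyBucket:
        def blob(self, name):
            return _DummyBlob(name)
    with open("sample.csv", "rb") as fh:
        result = generate_report_draft(fh, "sample.csv", "demo context", "local-user", "local-project", _DummyBucket())
    print(result["raw_md"][:500])
    print(result["chartUrls"])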