# sozo_gen.py
import os
import re
import io
import json
import uuid
import logging
import inspect
import tempfile
import subprocess
from pathlib import Path
from typing import Any, Dict, List

import requests
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter
from PIL import Image
import cv2

from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai  # legacy SDK: provides genai.GenerativeModel / genai.configure
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s')
FPS, WIDTH, HEIGHT = 24, 1280, 720
MAX_CHARTS, VIDEO_SCENES = 5, 5

# --- Gemini API Initialization ---
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable not set.")
genai.configure(api_key=API_KEY)
# --- Helper Functions ---
def load_dataframe_safely(buf, name: str):
    ext = Path(name).suffix.lower()
    df = (pd.read_excel if ext in (".xlsx", ".xls") else pd.read_csv)(buf)
    df.columns = df.columns.astype(str).str.strip()
    df = df.dropna(how="all")
    if df.empty or len(df.columns) == 0:
        raise ValueError("No usable data found")
    return df
def deepgram_tts(txt: str, voice_model: str):
    DG_KEY = os.getenv("DEEPGRAM_API_KEY")
    if not DG_KEY or not txt:
        return None
    txt = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
    try:
        r = requests.post(
            "https://api.deepgram.com/v1/speak",
            params={"model": voice_model},
            headers={"Authorization": f"Token {DG_KEY}", "Content-Type": "application/json"},
            json={"text": txt},
            timeout=30,
        )
        r.raise_for_status()
        return r.content
    except Exception as e:
        logging.error(f"Deepgram TTS failed: {e}")
        return None
def generate_silence_mp3(duration: float, out: Path):
    subprocess.run(
        ["ffmpeg", "-y", "-f", "lavfi", "-i", "anullsrc=r=44100:cl=mono",
         "-t", f"{duration:.3f}", "-q:a", "9", str(out)],
        check=True, capture_output=True,
    )

def audio_duration(path: str) -> float:
    try:
        res = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration",
             "-of", "default=nw=1:nk=1", path],
            text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True,
        )
        return float(res.stdout.strip())
    except Exception:
        return 5.0
TAG_RE = re.compile(
    r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]',
    re.I,
)

def extract_chart_tags(t):
    return list(dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")))
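# Example: extract_chart_tags('Intro <generate_chart: "bar | Revenue by region"> outro')
# returns ['bar | Revenue by region'] (duplicates removed, insertion order preserved).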
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)

def clean_narration(txt: str) -> str:
    txt = TAG_RE.sub("", txt)
    txt = re_scene.sub("", txt)
    phrases_to_remove = [r"as you can see in the chart", r"this chart shows",
                         r"the chart illustrates", r"in this visual", r"this graph displays"]
    for phrase in phrases_to_remove:
        txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
    txt = re.sub(r"\s*\([^)]*\)", "", txt)
    txt = re.sub(r"[\*#_]", "", txt)
    return re.sub(r"\s{2,}", " ", txt).strip()
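# Example: clean_narration('Scene 1: This chart shows revenue (in USD) rising.')
# returns 'revenue rising.' — scene labels, stock chart phrases, parentheticals,
# and markdown characters are all stripped before TTS.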
def placeholder_img() -> Image.Image:
    return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))

def generate_image_from_prompt(prompt: str) -> Image.Image:
    model_main = "gemini-1.5-flash-latest"
    full_prompt = "A clean business-presentation illustration: " + prompt
    try:
        model = genai.GenerativeModel(model_main)
        res = model.generate_content(full_prompt)
        img_part = next((part for part in res.candidates[0].content.parts
                         if getattr(part, "inline_data", None)), None)
        if img_part:
            return Image.open(io.BytesIO(img_part.inline_data.data)).convert("RGB")
        return placeholder_img()
    except Exception:
        return placeholder_img()
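# Note: plain Gemini text models generally return no inline image parts, so this
# helper usually falls back to placeholder_img(); swapping in an image-capable
# model is a deployment decision left as-is here.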
# --- Chart Generation System ---
class ChartSpecification:
    def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None,
                 agg_method: str = None, filter_condition: str = None, top_n: int = None,
                 color_scheme: str = "professional"):
        self.chart_type = chart_type
        self.title = title
        self.x_col = x_col
        self.y_col = y_col
        self.agg_method = agg_method or "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        self.color_scheme = color_scheme

def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
    enhanced_ctx = ctx_dict.copy()
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=['number']).columns.tolist()
    enhanced_ctx.update({"numeric_columns": numeric_cols, "categorical_columns": categorical_cols})
    return enhanced_ctx
class ChartGenerator:
    def __init__(self, llm, df: pd.DataFrame):
        self.llm = llm
        self.df = df
        self.enhanced_ctx = enhance_data_context(df, {
            "columns": list(df.columns),
            "shape": df.shape,
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        })

    def generate_chart_spec(self, description: str) -> ChartSpecification:
        spec_prompt = f"""
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
**Chart Request:** {description}
**Return a JSON specification with these exact fields:**
{{
  "chart_type": "bar|pie|line|scatter|hist",
  "title": "Professional chart title",
  "x_col": "column_name_for_x_axis",
  "y_col": "column_name_for_y_axis_or_null",
  "agg_method": "sum|mean|count|max|min|null",
  "top_n": "number_for_top_n_filtering_or_null"
}}
Return only the JSON specification, no additional text.
"""
        try:
            response = self.llm.invoke(spec_prompt).content.strip()
            if response.startswith("```json"):
                response = response[7:-3]
            elif response.startswith("```"):
                response = response[3:-3]
            spec_dict = json.loads(response)
            valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values()
                          if p.name not in ['reasoning', 'filter_condition', 'color_scheme']]
            filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
            return ChartSpecification(**filtered_dict)
        except Exception as e:
            logging.error(f"Spec generation failed: {e}. Using fallback.")
            return self._create_fallback_spec(description)
    def _create_fallback_spec(self, description: str) -> ChartSpecification:
        numeric_cols = self.enhanced_ctx['numeric_columns']
        categorical_cols = self.enhanced_ctx['categorical_columns']
        ctype = "bar"
        for t in ["pie", "line", "scatter", "hist"]:
            if t in description.lower():
                ctype = t
        x = categorical_cols[0] if categorical_cols else self.df.columns[0]
        y = (numeric_cols[0] if numeric_cols and len(self.df.columns) > 1
             else (self.df.columns[1] if len(self.df.columns) > 1 else None))
        return ChartSpecification(ctype, description, x, y)
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
    try:
        plot_data = prepare_plot_data(spec, df)
        plt.style.use('default')  # apply the style before the figure is created
        fig, ax = plt.subplots(figsize=(12, 8))
        if spec.chart_type == "bar":
            ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.tick_params(axis='x', rotation=45)
        elif spec.chart_type == "pie":
            ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        elif spec.chart_type == "line":
            ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "scatter":
            ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "hist":
            ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)
        ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()
        return True
    except Exception as e:
        logging.error(f"Static chart generation failed for '{spec.title}': {e}")
        return False
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
    if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
        raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
    if spec.chart_type in ["bar", "pie"]:
        if not spec.y_col:
            return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
        grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
        return grouped.nlargest(spec.top_n or 10)
    elif spec.chart_type == "line":
        return df.set_index(spec.x_col)[spec.y_col].sort_index()
    elif spec.chart_type == "scatter":
        return df[[spec.x_col, spec.y_col]].dropna()
    elif spec.chart_type == "hist":
        return df[spec.x_col].dropna()
    return df[spec.x_col]
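# Note: despite the Series annotation, scatter specs return a two-column DataFrame
# while bar/pie specs return an aggregated Series; execute_chart_spec and
# animate_chart rely on exactly this shape difference.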
# --- Animation & Video Generation ---
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    plot_data = prepare_plot_data(spec, df)
    frames = max(10, int(dur * fps))
    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    plt.tight_layout(pad=3.0)
    ctype = spec.chart_type
    if ctype == "pie":
        wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
        ax.set_title(spec.title)
        ax.axis('equal')
        def init():
            for w in wedges:
                w.set_alpha(0)
            return wedges
        def update(i):
            for w in wedges:
                w.set_alpha(i / (frames - 1))
            return wedges
    elif ctype == "bar":
        bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
        ax.set_ylim(0, plot_data.max() * 1.1 if not pd.isna(plot_data.max()) and plot_data.max() > 0 else 1)
        ax.set_title(spec.title)
        plt.xticks(rotation=45, ha="right")
        def init():
            return bars
        def update(i):
            for b, h in zip(bars, plot_data.values):
                b.set_height(h * (i / (frames - 1)))
            return bars
    elif ctype == "scatter":
        scat = ax.scatter([], [], alpha=0.7)
        x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
        ax.set_xlim(x_full.min(), x_full.max())
        ax.set_ylim(y_full.min(), y_full.max())
        ax.set_title(spec.title)
        ax.grid(alpha=.3)
        ax.set_xlabel(spec.x_col)
        ax.set_ylabel(spec.y_col)
        def init():
            scat.set_offsets(np.empty((0, 2)))
            return [scat]
        def update(i):
            k = max(1, int(len(x_full) * (i / (frames - 1))))
            scat.set_offsets(plot_data.iloc[:k].values)
            return [scat]
    elif ctype == "hist":
        _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
        ax.set_title(spec.title)
        ax.set_xlabel(spec.x_col)
        ax.set_ylabel("Frequency")
        def init():
            for p in patches:
                p.set_alpha(0)
            return patches
        def update(i):
            for p in patches:
                p.set_alpha((i / (frames - 1)) * 0.7)
            return patches
    else:  # line
        line, = ax.plot([], [], lw=2)
        if not plot_data.index.is_monotonic_increasing:
            plot_data = plot_data.sort_index()
        x_full, y_full = plot_data.index, plot_data.values
        ax.set_xlim(x_full.min(), x_full.max())
        ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
        ax.set_title(spec.title)
        ax.grid(alpha=.3)
        ax.set_xlabel(spec.x_col)
        ax.set_ylabel(spec.y_col)
        def init():
            line.set_data([], [])
            return [line]
        def update(i):
            k = max(2, int(len(x_full) * (i / (frames - 1))))
            line.set_data(x_full[:k], y_full[:k])
            return [line]
    anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
    anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
    plt.close(fig)
    return str(out)
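# Note: animate_chart's FFMpegWriter, like the ffmpeg/ffprobe subprocess helpers
# above, shells out to the ffmpeg binary, which must be installed and on PATH.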
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) -> str:
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
    total_frames = max(1, int(dur * fps))
    for i in range(total_frames):
        alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
        frame = cv2.addWeighted(img, alpha, np.zeros_like(img), 1 - alpha, 0)
        video_writer.write(frame)
    video_writer.release()
    return str(out)
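# animate_image_fade expects a WIDTH x HEIGHT uint8 BGR frame. Usage sketch
# (hypothetical values): fade in a blank frame over two seconds.
#   frame = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
#   animate_image_fade(frame, 2.0, Path(tempfile.gettempdir()) / "fade.mp4")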
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
            temp_png = Path(temp_png_file.name)
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
        chart_generator = ChartGenerator(llm, df)
        chart_spec = chart_generator.generate_chart_spec(desc)
        if execute_chart_spec(chart_spec, df, temp_png):
            img = cv2.imread(str(temp_png))
            os.unlink(temp_png)
            img_resized = cv2.resize(img, (WIDTH, HEIGHT))
            return animate_image_fade(img_resized, dur, out)
        else:
            img = generate_image_from_prompt(f"A professional business chart showing {desc}")
            img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
            return animate_image_fade(img_cv, dur, out)
def concat_media(file_paths: List[str], output_path: Path):
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        import shutil
        shutil.copy2(valid_paths[0], str(output_path))
        return
    list_file = output_path.with_suffix(".txt")
    with open(list_file, 'w') as f:
        for path in valid_paths:
            f.write(f"file '{Path(path).resolve()}'\n")
    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        list_file.unlink(missing_ok=True)
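# Note: the concat demuxer with '-c copy' assumes every input shares codec,
# resolution, and timebase. The clips here come from two different writers
# (matplotlib's FFMpegWriter and OpenCV's mp4v), so re-encoding with
# '-c:v libx264' is the safer choice if the stream-copied concat glitches.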
# --- Main Business Logic Functions for Flask ---
def sanitize_for_firebase_key(text: str) -> str:
    """Replaces Firebase-forbidden characters in a string with underscores."""
    forbidden_chars = ['.', '$', '#', '[', ']', '/']
    for char in forbidden_chars:
        text = text.replace(char, '_')
    return text
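# Example: sanitize_for_firebase_key('bar | Rev. per region ($)') returns
# 'bar | Rev_ per region (_)' — safe to use as a Realtime Database key.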
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
    """
    Enhanced autonomous data analysis function that intelligently analyzes any dataset
    and generates comprehensive, domain-appropriate reports with contextual visualizations.
    Maintains backward compatibility with the existing function signature and outputs.
    """
    logging.info(f"Generating enhanced autonomous report draft for project {project_id}")
    # Load data safely (existing functionality preserved)
    df = load_dataframe_safely(buf, name)
    # Initialize LLM (existing setup preserved)
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    # Enhanced autonomous data analysis
    try:
        # Stage 1: Intelligent Data Classification and Deep Analysis
        autonomous_context = perform_autonomous_data_analysis(df, ctx, name)
        # Stage 2: Generate Enhanced Report with Intelligent Narrative
        enhanced_report = generate_intelligent_report(llm, autonomous_context)
        # Stage 3: Smart Chart Generation (returns the markdown with sanitized tags)
        enhanced_report, chart_urls = generate_autonomous_charts(llm, df, enhanced_report, uid, project_id, bucket)
        # Preserve original output structure
        return {"raw_md": enhanced_report, "chartUrls": chart_urls}
    except Exception as e:
        logging.error(f"Enhanced analysis failed, falling back to original: {str(e)}")
        # Fall back to the original logic if the enhanced pipeline fails
        return generate_original_report(df, llm, ctx, uid, project_id, bucket)
def perform_autonomous_data_analysis(df: pd.DataFrame, user_ctx: str, filename: str) -> Dict[str, Any]:
    """
    Performs comprehensive autonomous analysis of the dataset to understand its nature,
    domain, and analytical potential.
    """
    logging.info("Performing autonomous data analysis...")
    # Basic data profiling (dtypes stringified so the context stays JSON-serializable)
    basic_info = {
        "shape": df.shape,
        "columns": list(df.columns),
        "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "filename": filename,
        "user_context": user_ctx
    }
    # Intelligent domain classification
    domain_analysis = classify_dataset_domain(df, filename)
    # Advanced statistical analysis
    statistical_profile = generate_statistical_profile(df)
    # Relationship discovery
    relationships = discover_data_relationships(df)
    # Temporal analysis if applicable
    temporal_insights = analyze_temporal_patterns(df)
    # Data quality assessment
    quality_metrics = assess_data_quality(df)
    # Business context inference
    business_context = infer_business_context(df, domain_analysis)
    return {
        "basic_info": basic_info,
        "domain": domain_analysis,
        "statistics": statistical_profile,
        "relationships": relationships,
        "temporal": temporal_insights,
        "quality": quality_metrics,
        "business_context": business_context,
        "analysis_complexity": determine_analysis_complexity(df, domain_analysis)
    }
def classify_dataset_domain(df: pd.DataFrame, filename: str) -> Dict[str, Any]:
    """
    Intelligently classifies the dataset domain based on column patterns, data types,
    and semantic analysis.
    """
    domain_indicators = {
        "financial": ["amount", "price", "cost", "revenue", "profit", "transaction", "payment", "invoice"],
        "survey": ["rating", "satisfaction", "response", "score", "survey", "feedback", "opinion"],
        "scientific": ["measurement", "experiment", "test", "sample", "observation", "hypothesis", "variable"],
        "marketing": ["campaign", "click", "conversion", "customer", "lead", "acquisition", "retention"],
        "operational": ["process", "time", "duration", "status", "workflow", "performance", "efficiency"],
        "sales": ["order", "product", "quantity", "sales", "customer", "deal", "pipeline"],
        "hr": ["employee", "salary", "department", "performance", "training", "recruitment"],
        "healthcare": ["patient", "diagnosis", "treatment", "medical", "health", "symptom", "medication"]
    }
    # Analyze column names for domain indicators
    columns_lower = [col.lower() for col in df.columns]
    domain_scores = {}
    for domain, keywords in domain_indicators.items():
        score = sum(1 for col in columns_lower for keyword in keywords if keyword in col)
        domain_scores[domain] = score
    # Filename analysis
    filename_lower = filename.lower()
    for domain, keywords in domain_indicators.items():
        if any(keyword in filename_lower for keyword in keywords):
            domain_scores[domain] = domain_scores.get(domain, 0) + 2
    # Data type analysis
    numeric_ratio = len(df.select_dtypes(include=[np.number]).columns) / len(df.columns)
    categorical_ratio = len(df.select_dtypes(include=['object']).columns) / len(df.columns)
    # Determine primary domain; fall back to "general" when nothing matched at all
    primary_domain = max(domain_scores, key=domain_scores.get) if domain_scores else "general"
    if domain_scores.get(primary_domain, 0) == 0:
        primary_domain = "general"
    return {
        "primary_domain": primary_domain,
        "domain_confidence": domain_scores.get(primary_domain, 0),
        "domain_scores": domain_scores,
        "data_characteristics": {
            "numeric_ratio": numeric_ratio,
            "categorical_ratio": categorical_ratio,
            "is_time_series": detect_time_series(df),
            "is_transactional": detect_transactional_data(df),
            "is_experimental": detect_experimental_data(df)
        }
    }
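# Example: columns ['order_id', 'sales_amount', 'region'] score sales=2 ("order",
# "sales") and financial=1 ("amount"), so primary_domain becomes "sales" with
# confidence 2.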
def generate_statistical_profile(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Generates a comprehensive statistical profile of the dataset.
    """
    profile = {
        "summary_stats": {},
        "correlations": {},
        "distributions": {},
        "outliers": {},
        "missing_data": {}
    }
    # Summary statistics for numeric columns
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    if len(numeric_cols) > 0:
        profile["summary_stats"] = df[numeric_cols].describe().to_dict()
    # Correlation analysis
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        # Find strong correlations
        strong_corrs = []
        for i in range(len(corr_matrix.columns)):
            for j in range(i + 1, len(corr_matrix.columns)):
                corr_val = corr_matrix.iloc[i, j]
                if abs(corr_val) > 0.7:  # Strong correlation threshold
                    strong_corrs.append({
                        "var1": corr_matrix.columns[i],
                        "var2": corr_matrix.columns[j],
                        "correlation": float(corr_val)
                    })
        profile["correlations"] = {"strong_correlations": strong_corrs}
    # Categorical analysis
    categorical_cols = df.select_dtypes(include=['object']).columns
    if len(categorical_cols) > 0:
        profile["categorical_analysis"] = {}
        for col in categorical_cols:
            profile["categorical_analysis"][col] = {
                "unique_count": int(df[col].nunique()),
                # Cast numpy ints so the profile stays JSON-serializable downstream
                "top_values": {k: int(v) for k, v in df[col].value_counts().head(5).items()}
            }
    # Missing data analysis
    missing_data = df.isnull().sum()
    profile["missing_data"] = {
        "columns_with_missing": {k: int(v) for k, v in missing_data[missing_data > 0].items()},
        "total_missing_percentage": float(df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
    }
    return profile
def discover_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Discovers meaningful relationships and patterns in the data.
    """
    relationships = {
        "key_relationships": [],
        "patterns": [],
        "anomalies": []
    }
    # Identify potential key relationships (each unordered pair is checked once)
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    if len(numeric_cols) > 1:
        for i, col1 in enumerate(numeric_cols):
            for col2 in numeric_cols[i + 1:]:
                correlation = df[col1].corr(df[col2])
                if abs(correlation) > 0.5:  # Moderate to strong correlation
                    relationships["key_relationships"].append({
                        "variable1": col1,
                        "variable2": col2,
                        "relationship_strength": float(correlation),
                        "relationship_type": "positive" if correlation > 0 else "negative"
                    })
    # Identify patterns in categorical data
    categorical_cols = df.select_dtypes(include=['object']).columns
    for col in categorical_cols:
        if df[col].nunique() < 20:  # Reasonable number of categories
            value_counts = df[col].value_counts()
            if len(value_counts) > 0:
                relationships["patterns"].append({
                    "column": col,
                    "pattern_type": "categorical_distribution",
                    "dominant_category": value_counts.index[0],
                    "dominance_percentage": float(value_counts.iloc[0] / len(df)) * 100
                })
    return relationships
def analyze_temporal_patterns(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Analyzes temporal patterns if time-based columns are detected.
    """
    temporal_insights = {"has_temporal_data": False}
    # Detect date/time columns
    date_columns = []
    for col in df.columns:
        if df[col].dtype == 'datetime64[ns]' or 'date' in col.lower() or 'time' in col.lower():
            try:
                pd.to_datetime(df[col])
                date_columns.append(col)
            except Exception:
                continue
    if date_columns:
        temporal_insights["has_temporal_data"] = True
        temporal_insights["date_columns"] = date_columns
        # Analyze temporal patterns for the first date column
        primary_date_col = date_columns[0]
        df_temp = df.copy()
        df_temp[primary_date_col] = pd.to_datetime(df_temp[primary_date_col])
        temporal_insights["temporal_analysis"] = {
            "date_range": {
                "start": df_temp[primary_date_col].min().strftime('%Y-%m-%d'),
                "end": df_temp[primary_date_col].max().strftime('%Y-%m-%d')
            },
            "time_span_days": (df_temp[primary_date_col].max() - df_temp[primary_date_col].min()).days,
            "frequency": detect_temporal_frequency(df_temp[primary_date_col])
        }
    return temporal_insights
def assess_data_quality(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Assesses data quality and identifies potential issues.
    """
    quality_metrics = {
        "overall_quality_score": 0,
        "quality_issues": [],
        "data_completeness": 0,
        "data_consistency": {}
    }
    # Completeness assessment
    completeness = (1 - df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100
    quality_metrics["data_completeness"] = float(completeness)
    # Identify quality issues
    if completeness < 95:
        quality_metrics["quality_issues"].append("Missing data detected")
    # Check for duplicates
    duplicate_rows = int(df.duplicated().sum())
    if duplicate_rows > 0:
        quality_metrics["quality_issues"].append(f"{duplicate_rows} duplicate rows found")
    # Check for inconsistent data types (astype(str) keeps NaN rows from
    # poisoning the .str accessor result)
    for col in df.columns:
        if df[col].dtype == 'object':
            numeric_like = df[col].astype(str).str.isnumeric()
            if numeric_like.any() and not numeric_like.all():
                quality_metrics["quality_issues"].append(f"Inconsistent data types in {col}")
    # Calculate overall quality score
    base_score = 100
    base_score -= (100 - completeness) * 0.5  # Penalize missing data
    base_score -= len(quality_metrics["quality_issues"]) * 5  # Penalize each quality issue
    quality_metrics["overall_quality_score"] = max(0, float(base_score))
    return quality_metrics
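# Worked example: 90% completeness and no other issues yields one "Missing data
# detected" entry, so base_score = 100 - (10 * 0.5) - (1 * 5) = 90.0.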
def infer_business_context(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> Dict[str, Any]:
    """
    Infers business context and potential use cases based on the data characteristics.
    """
    domain = domain_analysis["primary_domain"]
    context_mapping = {
        "financial": {
            "key_metrics": ["Revenue", "Profit", "Cost", "ROI"],
            "typical_analyses": ["Trend analysis", "Profitability analysis", "Budget vs actual"],
            "stakeholders": ["CFO", "Finance team", "Executive leadership"]
        },
        "survey": {
            "key_metrics": ["Satisfaction scores", "Response rates", "Sentiment"],
            "typical_analyses": ["Satisfaction analysis", "Demographic breakdown", "Correlation analysis"],
            "stakeholders": ["Marketing team", "Product managers", "Customer success"]
        },
        "scientific": {
            "key_metrics": ["Statistical significance", "Effect size", "Confidence intervals"],
            "typical_analyses": ["Hypothesis testing", "Regression analysis", "Experimental validation"],
            "stakeholders": ["Researchers", "Scientists", "Academic community"]
        },
        "marketing": {
            "key_metrics": ["Conversion rates", "Customer acquisition cost", "Campaign ROI"],
            "typical_analyses": ["Campaign performance", "Customer segmentation", "Attribution analysis"],
            "stakeholders": ["Marketing team", "CMO", "Sales team"]
        }
    }
    return context_mapping.get(domain, {
        "key_metrics": ["Performance indicators", "Trends", "Patterns"],
        "typical_analyses": ["Descriptive analysis", "Trend identification", "Pattern recognition"],
        "stakeholders": ["Business stakeholders", "Decision makers"]
    })
def generate_intelligent_report(llm, autonomous_context: Dict[str, Any]) -> str:
    """
    Generates an intelligent, domain-appropriate report with organic storytelling.
    """
    # Open-ended prompt that lets the model choose its own structure
    enhanced_prompt = f"""
You are a world-class data analyst who has just been handed this dataset to analyze. Look at the data characteristics and tell me the most compelling story you can find.
**DATASET CONTEXT:**
{json.dumps(autonomous_context, indent=2)}
**YOUR MISSION:**
Analyze this data like you would if a CEO walked into your office and said "I need to understand what this data is telling us." Write a report that would make them say "This is exactly what I needed to know."
**GUIDELINES:**
- Don't follow a rigid structure - let the data guide your narrative
- Choose your own headings and sections based on what the data reveals
- Write like you're presenting findings to someone who needs to make important decisions
- Include specific numbers and insights that matter
- Insert chart recommendations like: `<generate_chart: "chart_type | description">`
- Valid chart types: bar, pie, line, scatter, hist, box, heatmap
- Only recommend charts that truly support your narrative
**FORGET TEMPLATES - TELL THE STORY:**
What's the most interesting, important, or surprising thing this data reveals? Start there and build your entire report around that central insight. Make it compelling, make it actionable, make it memorable.
Be the data analyst who gets promoted because they don't just present data - they reveal insights that drive business decisions.
"""
    return llm.invoke(enhanced_prompt).content
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket):
    """
    Generates charts autonomously based on the report content and data characteristics.
    Returns the (possibly rewritten) markdown together with a key->URL mapping.
    """
    # Extract chart descriptions from the enhanced report
    chart_descs = extract_chart_tags(report_md)[:MAX_CHARTS]
    chart_urls = {}
    if not chart_descs:
        # If no charts were specified, generate intelligent defaults
        chart_descs = generate_intelligent_chart_suggestions(df, llm)
    chart_generator = ChartGenerator(llm, df)
    for desc in chart_descs:
        try:
            # Create a safe key for Firebase
            safe_desc = sanitize_for_firebase_key(desc)
            # Rewrite chart tags in the markdown so they match the sanitized keys
            report_md = report_md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
            report_md = report_md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
            # Generate chart
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                img_path = Path(temp_file.name)
            try:
                chart_spec = chart_generator.generate_chart_spec(desc)
                if execute_chart_spec(chart_spec, df, img_path):
                    blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(img_path))
                    chart_urls[safe_desc] = blob.public_url
                    logging.info(f"Generated autonomous chart: {safe_desc}")
            finally:
                if os.path.exists(img_path):
                    os.unlink(img_path)
        except Exception as e:
            logging.error(f"Failed to generate chart '{desc}': {str(e)}")
            continue
    # The replacements above only mutate a local copy of the string, so the
    # rewritten markdown must be handed back to the caller explicitly.
    return report_md, chart_urls
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
    """
    Generates intelligent chart suggestions based on data characteristics.
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    categorical_cols = df.select_dtypes(include=['object']).columns
    suggestions = []
    # Time series chart if temporal data exists
    if detect_time_series(df):
        suggestions.append("line | Time series trend analysis | Show temporal patterns")
    # Distribution chart for numeric data
    if len(numeric_cols) > 0:
        main_numeric = numeric_cols[0]
        suggestions.append(f"hist | Distribution of {main_numeric} | Understand data distribution")
    # Correlation analysis if multiple numeric columns
    if len(numeric_cols) > 1:
        suggestions.append("scatter | Correlation analysis | Identify relationships between variables")
    # Categorical breakdown
    if len(categorical_cols) > 0:
        main_categorical = categorical_cols[0]
        suggestions.append(f"bar | {main_categorical} breakdown | Show categorical distribution")
    return suggestions[:MAX_CHARTS]
# Helper functions (preserve existing functionality)
def detect_time_series(df: pd.DataFrame) -> bool:
    """Detect if the dataset contains time series data."""
    for col in df.columns:
        if 'date' in col.lower() or 'time' in col.lower():
            return True
        try:
            # Note: pd.to_datetime is permissive, so numeric columns can also parse
            pd.to_datetime(df[col])
            return True
        except Exception:
            continue
    return False
def detect_transactional_data(df: pd.DataFrame) -> bool:
    """Detect if the dataset contains transactional data."""
    transaction_indicators = ['transaction', 'payment', 'order', 'invoice', 'amount', 'quantity']
    columns_lower = [col.lower() for col in df.columns]
    return any(indicator in col for col in columns_lower for indicator in transaction_indicators)

def detect_experimental_data(df: pd.DataFrame) -> bool:
    """Detect if the dataset contains experimental data."""
    experimental_indicators = ['test', 'experiment', 'trial', 'group', 'treatment', 'control']
    columns_lower = [col.lower() for col in df.columns]
    return any(indicator in col for col in columns_lower for indicator in experimental_indicators)
def detect_temporal_frequency(date_series: pd.Series) -> str:
    """Detect the frequency of temporal data."""
    if len(date_series) < 2:
        return "insufficient_data"
    # Calculate time differences
    time_diffs = date_series.sort_values().diff().dropna()
    median_diff = time_diffs.median()
    if median_diff <= pd.Timedelta(days=1):
        return "daily"
    elif median_diff <= pd.Timedelta(days=7):
        return "weekly"
    elif median_diff <= pd.Timedelta(days=31):
        return "monthly"
    else:
        return "irregular"
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
    """Determine the complexity level of analysis required."""
    complexity_factors = 0
    # Data size factor
    if len(df) > 10000:
        complexity_factors += 1
    if len(df.columns) > 20:
        complexity_factors += 1
    # Data type diversity
    if len(df.select_dtypes(include=[np.number]).columns) > 5:
        complexity_factors += 1
    if len(df.select_dtypes(include=['object']).columns) > 5:
        complexity_factors += 1
    # Domain complexity
    if domain_analysis["primary_domain"] in ["scientific", "financial"]:
        complexity_factors += 1
    if complexity_factors >= 3:
        return "high"
    elif complexity_factors >= 2:
        return "medium"
    else:
        return "low"
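# Example: 15,000 rows (+1), 25 columns (+1), and 6 numeric columns (+1) give
# three complexity factors, so the function returns "high".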
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
    """
    Fallback to the original report generation logic if the enhanced version fails.
    """
    logging.info("Using fallback report generation")
    # Original logic preserved
    ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
    enhanced_ctx = enhance_data_context(df, ctx_dict)
    report_prompt = f"""
You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
**Instructions:**
1. **Executive Summary**: Start with a high-level summary of key findings.
2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
   Valid chart types: bar, pie, line, scatter, hist.
Generate insights that would be valuable to C-level executives.
"""
    md = llm.invoke(report_prompt).content
    chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)
    for desc in chart_descs:
        safe_desc = sanitize_for_firebase_key(desc)
        md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
        md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
        try:
            chart_spec = chart_generator.generate_chart_spec(desc)
            if execute_chart_spec(chart_spec, df, img_path):
                blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
                blob = bucket.blob(blob_name)
                blob.upload_from_filename(str(img_path))
                chart_urls[safe_desc] = blob.public_url
        finally:
            if os.path.exists(img_path):
                os.unlink(img_path)
    return {"raw_md": md, "chartUrls": chart_urls}
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
    """
    Generates a basic fallback report when enhanced generation fails.
    """
    basic_info = autonomous_context["basic_info"]
    domain = autonomous_context["domain"]["primary_domain"]
    return f"""
# What This Data Reveals
Looking at this {domain} dataset with {basic_info['shape'][0]} records, there are several key insights worth highlighting.
## The Numbers Tell a Story
This dataset contains {basic_info['shape'][1]} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.
<generate_chart: "bar | Data overview showing key metrics">
## What You Should Know
The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicates multiple analytical opportunities.
## Next Steps
Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.
*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
"""
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
    logging.info(f"Generating single chart '{description}' for project {project_id}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    chart_generator = ChartGenerator(llm, df)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img_path = Path(temp_file.name)
    try:
        chart_spec = chart_generator.generate_chart_spec(description)
        if execute_chart_spec(chart_spec, df, img_path):
            blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(str(img_path))
            logging.info(f"Uploaded single chart to {blob.public_url}")
            return blob.public_url
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
    return None
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
    logging.info(f"Generating video for project {project_id} with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.2)
    story_prompt = (f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. "
                    f"Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. "
                    f"Report: {raw_md}")
    script = llm.invoke(story_prompt).content
    scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
    video_parts, audio_parts, temps = [], [], []
    for sc in scenes:
        descs, narrative = extract_chart_tags(sc), clean_narration(sc)
        audio_bytes = deepgram_tts(narrative, voice_model)
        mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
        if audio_bytes:
            mp3.write_bytes(audio_bytes)
            dur = audio_duration(str(mp3))
            if dur <= 0.1:
                dur = 5.0
        else:
            dur = 5.0
            generate_silence_mp3(dur, mp3)
        audio_parts.append(str(mp3))
        temps.append(mp3)
        mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
        if descs:
            safe_chart(descs[0], df, dur, mp4)
        else:
            img = generate_image_from_prompt(narrative)
            img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
            animate_image_fade(img_cv, dur, mp4)
        video_parts.append(str(mp4))
        temps.append(mp4)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
         tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
         tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as final_vid:
        silent_vid_path = Path(temp_vid.name)
        audio_mix_path = Path(temp_aud.name)
        final_vid_path = Path(final_vid.name)
    concat_media(video_parts, silent_vid_path)
    concat_media(audio_parts, audio_mix_path)
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
         "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
         "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
        check=True, capture_output=True,
    )
    blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
    blob = bucket.blob(blob_name)
    blob.upload_from_filename(str(final_vid_path))
    logging.info(f"Uploaded video to {blob.public_url}")
    for p in temps + [silent_vid_path, audio_mix_path, final_vid_path]:
        if os.path.exists(p):
            os.unlink(p)
    return blob.public_url
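# --- Local smoke test (illustrative sketch, not part of the Flask app) ---
# Assumes GOOGLE_API_KEY is set and the ffmpeg binaries are installed. The fake
# bucket below is hypothetical: it only mimics the two google-cloud-storage
# calls this module makes (bucket.blob(name).upload_from_filename / .public_url).
if __name__ == "__main__":
    class _FakeBlob:
        def __init__(self, name):
            self.name = name
            self.public_url = f"fake://{name}"

        def upload_from_filename(self, path):
            logging.info(f"(fake upload) {path} -> {self.name}")

    class _FakeBucket:
        def blob(self, name):
            return _FakeBlob(name)

    # Tiny in-memory CSV exercising the report pipeline end to end
    demo_csv = io.BytesIO(b"region,sales\nNorth,100\nSouth,240\nEast,180\n")
    draft = generate_report_draft(demo_csv, "demo.csv", "demo sales data", "uid", "proj", _FakeBucket())
    print(draft["raw_md"][:500])
    print(draft["chartUrls"])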