KoreAI-API / sozo_gen.py
rairo's picture
Update sozo_gen.py
bb892b9 verified
raw
history blame
21.1 kB
# sozo_gen.py
import inspect
import io
import json
import logging
import os
import re
import shutil
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List

import cv2
import matplotlib
matplotlib.use("Agg")  # must run before matplotlib.pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from google import genai
from langchain_google_genai import ChatGoogleGenerativeAI
from matplotlib.animation import FuncAnimation, FFMpegWriter
from PIL import Image
# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s',
)
# Frame rate and pixel geometry shared by the video renderers in this module.
FPS = 24
WIDTH = 1280
HEIGHT = 720
# MAX_CHARTS caps the chart tags rendered per report draft; VIDEO_SCENES is the
# scene count requested from the LLM when scripting a video.
MAX_CHARTS = 5
VIDEO_SCENES = 5
# --- Gemini API Initialization ---
# The Gemini key is read from the GOOGLE_API_KEY environment variable; importing
# this module fails immediately when the variable is missing or empty.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if not API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable not set.")
# No global genai.configure(api_key=...) call is made here; the LangChain
# clients below receive API_KEY explicitly.
# --- Helper Functions ---
def load_dataframe_safely(buf, name: str):
    """Load an uploaded table into a DataFrame and apply basic clean-up.

    Args:
        buf: File path or file-like object handed straight to pandas.
        name: Original file name; only its extension is inspected.
            ``.xlsx``/``.xls`` are read with ``pd.read_excel``, everything
            else with ``pd.read_csv``.

    Returns:
        The DataFrame with column names converted to stripped strings and
        rows that are entirely empty removed.

    Raises:
        ValueError: If no rows or no columns remain after clean-up.
    """
    suffix = Path(name).suffix.lower()
    if suffix in (".xlsx", ".xls"):
        frame = pd.read_excel(buf)
    else:
        frame = pd.read_csv(buf)

    frame.columns = frame.columns.astype(str).str.strip()
    frame = frame.dropna(how="all")

    has_no_columns = len(frame.columns) == 0
    if frame.empty or has_no_columns:
        raise ValueError("No usable data found")
    return frame
def deepgram_tts(txt: str, voice_model: str):
    """Synthesise speech for ``txt`` through Deepgram's ``/v1/speak`` endpoint.

    Args:
        txt: Narration text. It is sanitised and truncated before sending.
        voice_model: Value passed as the ``model`` query parameter.

    Returns:
        The raw response body (audio bytes) on success, or ``None`` when the
        ``DEEPGRAM_API_KEY`` environment variable is unset, when there is no
        speakable text, or when the request fails (the failure is logged).
    """
    api_key = os.getenv("DEEPGRAM_API_KEY")
    if not api_key or not txt:
        return None

    # Keep only word characters, whitespace and basic punctuation, then cap the
    # payload at 1000 characters.
    speakable = re.sub(r"[^\w\s.,!?;:-]", "", txt)[:1000]
    if not speakable.strip():
        # Sanitising can remove everything (e.g. a scene made only of symbols);
        # skip the request instead of posting empty text.
        return None

    try:
        response = requests.post(
            "https://api.deepgram.com/v1/speak",
            params={"model": voice_model},
            headers={"Authorization": f"Token {api_key}", "Content-Type": "application/json"},
            json={"text": speakable},
            timeout=30,
        )
        response.raise_for_status()
        return response.content
    except Exception as e:
        # Best effort by design: callers substitute silence when this returns None.
        logging.error(f"Deepgram TTS failed: {e}")
        return None
def generate_silence_mp3(duration: float, out: Path):
    """Write ``duration`` seconds of silence to ``out`` by invoking ffmpeg.

    The audio comes from ffmpeg's ``anullsrc`` lavfi source (44100 Hz, mono).
    Because the command runs with ``check=True``, a non-zero ffmpeg exit
    raises ``subprocess.CalledProcessError``.
    """
    command = [
        "ffmpeg", "-y",
        "-f", "lavfi",
        "-i", "anullsrc=r=44100:cl=mono",
        "-t", f"{duration:.3f}",
        "-q:a", "9",
        str(out),
    ]
    subprocess.run(command, check=True, capture_output=True)
def audio_duration(path: str) -> float:
    """Return the duration of ``path`` in seconds as reported by ffprobe.

    Any failure (ffprobe missing, non-zero exit, unparsable output) yields the
    default of 5.0 seconds instead of raising.
    """
    probe_cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=nw=1:nk=1",
        path,
    ]
    try:
        probe = subprocess.run(
            probe_cmd,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
        seconds = float(probe.stdout.strip())
    except Exception:
        return 5.0
    return seconds
# Matches chart tags such as <generate_chart: "bar | sales by region"> or
# [generate_chart: pie | share]; the named group "d" captures the description.
TAG_RE = re.compile(
    r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]',
    re.I,
)


def extract_chart_tags(t):
    """Return the distinct chart descriptions tagged in ``t``, in first-seen order.

    ``None`` or empty input yields an empty list.
    """
    ordered_unique = {}
    for match in TAG_RE.finditer(t or ""):
        ordered_unique.setdefault(match.group("d").strip(), None)
    return list(ordered_unique)


# Matches a leading "Scene <n>" label (with optional ':', '.', '-' or spaces
# after the number) at the start of any line.
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
def clean_narration(txt: str) -> str:
    """Strip markup from a scene script so only speakable narration remains.

    Removes, in order: chart tags, leading "Scene N" labels, a fixed set of
    chart-referencing filler phrases (case-insensitive), parenthesised asides,
    and the Markdown characters ``*``, ``#`` and ``_``. Runs of whitespace are
    then collapsed to a single space and the result is trimmed.
    """
    filler_patterns = (
        r"as you can see in the chart",
        r"this chart shows",
        r"the chart illustrates",
        r"in this visual",
        r"this graph displays",
    )

    narration = re_scene.sub("", TAG_RE.sub("", txt))
    for pattern in filler_patterns:
        narration = re.sub(pattern, "", narration, flags=re.IGNORECASE)
    narration = re.sub(r"\s*\([^)]*\)", "", narration)
    narration = re.sub(r"[\*#_]", "", narration)
    narration = re.sub(r"\s{2,}", " ", narration)
    return narration.strip()
def placeholder_img() -> Image.Image:
    """Return a plain light-grey RGB image of WIDTH x HEIGHT pixels."""
    light_grey = (230, 230, 230)
    canvas_size = (WIDTH, HEIGHT)
    return Image.new("RGB", canvas_size, light_grey)
def generate_image_from_prompt(prompt: str) -> Image.Image:
    """Ask Gemini for an illustration of ``prompt``; never raises.

    The primary image model is tried first, then the fallback model. If
    neither returns inline image data, or the client cannot be created, a
    grey placeholder from ``placeholder_img()`` is returned instead.

    Args:
        prompt: Free-text description; it is prefixed with a fixed
            "business-presentation illustration" instruction.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    model_main = "gemini-2.0-flash-exp-image-generation"
    model_fallback = "gemini-2.0-flash-preview-image-generation"
    full_prompt = "A clean business-presentation illustration: " + prompt

    def fetch(client, model_name):
        """Return the first inline image produced by ``model_name``, or None."""
        try:
            # `from google import genai` is the google-genai SDK. It has no
            # `genai.GenerativeModel`; generation goes through a Client, and
            # image output must be requested via response_modalities.
            res = client.models.generate_content(
                model=model_name,
                contents=full_prompt,
                config={"response_modalities": ["TEXT", "IMAGE"]},
            )
            candidates = res.candidates or []
            if not candidates or candidates[0].content is None:
                logging.warning("Image model %s returned no candidates.", model_name)
                return None
            for part in candidates[0].content.parts or []:
                inline = getattr(part, "inline_data", None)
                if inline is not None and inline.data:
                    return Image.open(io.BytesIO(inline.data)).convert("RGB")
            logging.warning("Image model %s returned no inline image data.", model_name)
            return None
        except Exception as e:
            # Best effort: report the failure and let the caller try the next model.
            logging.warning("Image generation with %s failed: %s", model_name, e)
            return None

    try:
        client = genai.Client(api_key=API_KEY)
    except Exception as e:
        logging.error("Could not create Gemini client for image generation: %s", e)
        return placeholder_img()

    for model_name in (model_main, model_fallback):
        img = fetch(client, model_name)
        if img is not None:
            return img
    return placeholder_img()
# --- Chart Generation System ---
class ChartSpecification:
    """Plain container describing one chart to render.

    Attributes mirror the constructor arguments. ``agg_method`` falls back to
    ``"sum"`` whenever a falsy value (``None`` or ``""``) is supplied.
    """

    def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
        # What to draw and how to label it.
        self.chart_type = chart_type
        self.title = title
        # Source columns for the two axes.
        self.x_col = x_col
        self.y_col = y_col
        # Aggregation applied when grouping; "sum" unless the caller says otherwise.
        self.agg_method = agg_method if agg_method else "sum"
        self.filter_condition = filter_condition
        self.top_n = top_n
        self.color_scheme = color_scheme
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
    """Return a copy of ``ctx_dict`` extended with the frame's column groupings.

    Adds ``"numeric_columns"`` (columns with a numeric dtype) and
    ``"categorical_columns"`` (all remaining columns). ``ctx_dict`` itself is
    not modified.
    """
    numeric_names = df.select_dtypes(include=['number']).columns.tolist()
    other_names = df.select_dtypes(exclude=['number']).columns.tolist()
    return {
        **ctx_dict,
        "numeric_columns": numeric_names,
        "categorical_columns": other_names,
    }
class ChartGenerator:
    """Turns a free-text chart request into a ``ChartSpecification`` via an LLM.

    The LLM receives a JSON summary of the DataFrame (columns, shape, dtypes,
    numeric/categorical split) together with the request. If its answer cannot
    be used, a heuristic fallback specification is built from the frame.
    """

    # Chart types the prompt offers to the LLM.
    _CHART_TYPES = ("bar", "pie", "line", "scatter", "hist")

    def __init__(self, llm, df: pd.DataFrame):
        """Store the LLM client and frame, and precompute the dataset summary."""
        self.llm = llm
        self.df = df
        self.enhanced_ctx = enhance_data_context(
            df,
            {
                "columns": list(df.columns),
                "shape": df.shape,
                "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
            },
        )

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (```json ... ```) if present.

        Unlike fixed-offset slicing, this does not eat JSON characters when the
        closing fence is missing.
        """
        text = text.strip()
        if text.startswith("```"):
            text = re.sub(r"^```[A-Za-z]*\s*", "", text)
            text = re.sub(r"\s*```$", "", text)
        return text.strip()

    @staticmethod
    def _normalise_optional_fields(spec_dict: dict) -> dict:
        """Map "null"-like strings to None and coerce ``top_n`` to a positive int.

        The prompt template shows "null" inside quoted strings and describes
        ``top_n`` as a string, so the model may answer with the text "null" or
        with "10" rather than JSON null / a number.
        """
        cleaned = dict(spec_dict)
        for key in ("y_col", "agg_method", "top_n"):
            value = cleaned.get(key)
            if isinstance(value, str) and value.strip().lower() in ("", "null", "none"):
                cleaned[key] = None
        top_n = cleaned.get("top_n")
        if top_n is not None:
            try:
                top_n = int(top_n)
            except (TypeError, ValueError):
                top_n = None
            cleaned["top_n"] = top_n if top_n and top_n > 0 else None
        return cleaned

    def generate_chart_spec(self, description: str) -> ChartSpecification:
        """Ask the LLM for a chart specification matching ``description``.

        Returns the parsed specification, or the result of
        ``_create_fallback_spec`` when the LLM call, JSON parsing or
        construction fails (the failure is logged).
        """
        # The prompt text is kept flush-left so the string sent to the model is unchanged.
        spec_prompt = f"""
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
**Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
**Chart Request:** {description}
**Return a JSON specification with these exact fields:**
{{
"chart_type": "bar|pie|line|scatter|hist", "title": "Professional chart title", "x_col": "column_name_for_x_axis",
"y_col": "column_name_for_y_axis_or_null", "agg_method": "sum|mean|count|max|min|null", "top_n": "number_for_top_n_filtering_or_null"
}}
Return only the JSON specification, no additional text.
"""
        try:
            response = self._strip_code_fence(self.llm.invoke(spec_prompt).content)
            spec_dict = self._normalise_optional_fields(json.loads(response))
            # Only forward keys the constructor accepts and that the LLM is allowed to set.
            valid_keys = [
                p.name
                for p in inspect.signature(ChartSpecification).parameters.values()
                if p.name not in ['reasoning', 'filter_condition', 'color_scheme']
            ]
            filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
            return ChartSpecification(**filtered_dict)
        except Exception as e:
            logging.error(f"Spec generation failed: {e}. Using fallback.")
            return self._create_fallback_spec(description)

    def _create_fallback_spec(self, description: str) -> ChartSpecification:
        """Build a specification from the description and column dtypes alone."""
        numeric_cols = self.enhanced_ctx['numeric_columns']
        categorical_cols = self.enhanced_ctx['categorical_columns']

        lowered = description.lower()
        # Report tags look like "chart_type | description", so trust an explicit prefix first.
        declared = lowered.split("|", 1)[0].strip()
        if declared in self._CHART_TYPES:
            ctype = declared
        else:
            # Otherwise look for the type as a word, not a substring, so that
            # e.g. "decline" does not select a line chart. The order keeps the
            # precedence of the earlier last-match-wins scan.
            ctype = next(
                (
                    t
                    for t in ("hist", "scatter", "line", "pie")
                    if re.search(rf"\b{t}(?:ogram|plot)?s?\b", lowered)
                ),
                "bar",
            )

        x = categorical_cols[0] if categorical_cols else self.df.columns[0]
        # Prefer a numeric column for y, but never the column already used for x.
        y = next((c for c in numeric_cols if c != x), None)
        if y is None:
            y = next((c for c in self.df.columns if c != x), None)
        return ChartSpecification(ctype, description, x, y)
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
    """Render ``spec`` as a static image at ``output_path``.

    Args:
        spec: Chart description; ``chart_type`` must be one of bar, pie, line,
            scatter or hist.
        df: Source data, passed to ``prepare_plot_data``.
        output_path: Destination file handed to ``Figure.savefig``.

    Returns:
        ``True`` when the image was written. ``False`` when data preparation
        or plotting failed, or the chart type is unsupported; the error is
        logged and nothing is raised.
    """
    fig = None
    try:
        plot_data = prepare_plot_data(spec, df)
        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 8))

        if spec.chart_type == "bar":
            ax.bar(plot_data.index.astype(str), plot_data.values, color='#2E86AB', alpha=0.8)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.tick_params(axis='x', rotation=45)
        elif spec.chart_type == "pie":
            ax.pie(plot_data.values, labels=plot_data.index, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        elif spec.chart_type == "line":
            ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "scatter":
            ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)
            ax.grid(True, alpha=0.3)
        elif spec.chart_type == "hist":
            ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black')
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)
        else:
            # Without this, an unknown type would save an empty titled canvas
            # and report success.
            raise ValueError(f"Unsupported chart type: {spec.chart_type}")

        ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20)
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
        return True
    except Exception as e:
        logging.error(f"Static chart generation failed for '{spec.title}': {e}")
        return False
    finally:
        # Close this specific figure on every path so failed renders do not
        # accumulate open figures in a long-running process.
        if fig is not None:
            plt.close(fig)
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> "pd.Series | pd.DataFrame":
    """Select and aggregate the data needed to draw ``spec``.

    Returns:
        * bar / pie: a Series of the ``top_n`` (default 10) largest values,
          either value counts of ``x_col`` (no ``y_col``) or ``y_col``
          aggregated per ``x_col`` with ``agg_method`` (default ``'sum'``).
        * line: the ``y_col`` Series indexed by ``x_col`` and sorted by index.
        * scatter: a two-column DataFrame ``[x_col, y_col]`` without NaN rows.
        * hist: the ``x_col`` Series without NaN values.
        * any other type: the raw ``x_col`` Series.

    Raises:
        ValueError: If ``x_col`` (or a given ``y_col``) is not a column of
            ``df``, or if a line/scatter chart is requested without ``y_col``.
    """
    if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns):
        raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")

    chart_type = spec.chart_type
    if chart_type in ("bar", "pie"):
        limit = spec.top_n or 10
        if not spec.y_col:
            return df[spec.x_col].value_counts().nlargest(limit)
        grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
        return grouped.nlargest(limit)

    if chart_type in ("line", "scatter"):
        if not spec.y_col:
            # Indexing the frame with None would otherwise surface as an
            # unhelpful "KeyError: None" in the callers' logs.
            raise ValueError(f"A {chart_type} chart requires y_col, but the chart spec has none (x_col: {spec.x_col})")
        if chart_type == "line":
            return df.set_index(spec.x_col)[spec.y_col].sort_index()
        return df[[spec.x_col, spec.y_col]].dropna()

    if chart_type == "hist":
        return df[spec.x_col].dropna()
    return df[spec.x_col]
# --- Animation & Video Generation ---
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
    """Render ``spec`` as an animated video clip written to ``out``.

    Pie and histogram charts fade in, bars grow from zero, scatter points and
    line segments are revealed progressively. Any chart type other than pie,
    bar, scatter or hist is drawn as a line.

    Args:
        spec: Chart description.
        df: Source data, passed to ``prepare_plot_data``.
        dur: Target clip length in seconds; at least 10 frames are rendered.
        out: Output video path.
        fps: Frames per second for both the animation and the writer.

    Returns:
        ``str(out)``.

    Raises:
        Whatever data preparation, Matplotlib or the ffmpeg writer raise; the
        figure is closed in every case.
    """
    plot_data = prepare_plot_data(spec, df)
    frames = max(10, int(dur * fps))
    last_frame = frames - 1  # frames >= 10, so this is never zero
    fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
    try:
        fig.tight_layout(pad=3.0)
        ctype = spec.chart_type

        if ctype == "pie":
            wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
            ax.set_title(spec.title)
            ax.axis('equal')

            def init():
                for w in wedges:
                    w.set_alpha(0)
                return wedges

            def update(i):
                for w in wedges:
                    w.set_alpha(i / last_frame)
                return wedges

        elif ctype == "bar":
            bars = ax.bar(plot_data.index.astype(str), np.zeros_like(plot_data.values, dtype=float), color="#1f77b4")
            peak = plot_data.max()
            ax.set_ylim(0, peak * 1.1 if not pd.isna(peak) and peak > 0 else 1)
            ax.set_title(spec.title)
            plt.xticks(rotation=45, ha="right")

            def init():
                return bars

            def update(i):
                for b, h in zip(bars, plot_data.values):
                    b.set_height(h * (i / last_frame))
                return bars

        elif ctype == "scatter":
            scat = ax.scatter([], [], alpha=0.7)
            x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
            ax.set_xlim(x_full.min(), x_full.max())
            ax.set_ylim(y_full.min(), y_full.max())
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            def init():
                scat.set_offsets(np.empty((0, 2)))
                return [scat]

            def update(i):
                k = max(1, int(len(x_full) * (i / last_frame)))
                scat.set_offsets(plot_data.iloc[:k].values)
                return [scat]

        elif ctype == "hist":
            _, _, patches = ax.hist(plot_data, bins=20, alpha=0)
            ax.set_title(spec.title)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel("Frequency")

            def init():
                for p in patches:
                    p.set_alpha(0)
                return patches

            def update(i):
                for p in patches:
                    p.set_alpha((i / last_frame) * 0.7)
                return patches

        else:  # line
            line, = ax.plot([], [], lw=2)
            if not plot_data.index.is_monotonic_increasing:
                plot_data = plot_data.sort_index()
            x_full, y_full = plot_data.index, plot_data.values
            y_min, y_max = y_full.min(), y_full.max()
            # Pad away from the data on both sides. Scaling a negative minimum
            # by 0.9 (or a negative maximum by 1.1) moves the limit *into* the
            # data and clips the curve.
            y_low = y_min * 0.9 if y_min >= 0 else y_min * 1.1
            y_high = y_max * 1.1 if y_max >= 0 else y_max * 0.9
            ax.set_xlim(x_full.min(), x_full.max())
            ax.set_ylim(y_low, y_high)
            ax.set_title(spec.title)
            ax.grid(alpha=.3)
            ax.set_xlabel(spec.x_col)
            ax.set_ylabel(spec.y_col)

            def init():
                line.set_data([], [])
                return [line]

            def update(i):
                k = max(2, int(len(x_full) * (i / last_frame)))
                line.set_data(x_full[:k], y_full[:k])
                return [line]

        anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
        # Save at the figure's own dpi so frames come out at WIDTH x HEIGHT
        # pixels, the same size animate_image_fade writes. Saving at dpi=144
        # scaled the frames to roughly 1843x1037.
        anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=fig.dpi)
    finally:
        # Release the figure even when plotting or encoding fails.
        plt.close(fig)
    return str(out)
def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = FPS) -> str:
    """Write a clip that fades ``img`` in from black.

    Args:
        img: Image array as used by OpenCV; it is resized to WIDTH x HEIGHT
            when its dimensions differ.
        dur: Clip length in seconds; at least one frame is written.
        out: Output video path (encoded with the ``mp4v`` fourcc).
        fps: Frames per second; defaults to the module-wide ``FPS``.

    Returns:
        ``str(out)``.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(out), fourcc, fps, (WIDTH, HEIGHT))
    try:
        if not video_writer.isOpened():
            # Stay best-effort, as before, but leave a trace explaining why the
            # clip may be missing or empty.
            logging.error(f"Could not open video writer for '{out}'.")

        # The writer was opened for WIDTH x HEIGHT frames, so make the input match.
        if img.shape[1] != WIDTH or img.shape[0] != HEIGHT:
            img = cv2.resize(img, (WIDTH, HEIGHT))

        total_frames = max(1, int(dur * fps))
        black = np.zeros_like(img)  # loop-invariant blend target
        for i in range(total_frames):
            alpha = i / (total_frames - 1) if total_frames > 1 else 1.0
            frame = cv2.addWeighted(img, alpha, black, 1 - alpha, 0)
            video_writer.write(frame)
    finally:
        # Always release the writer so the file handle is closed even when a
        # frame fails to blend or encode.
        video_writer.release()
    return str(out)
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
    """Produce a video clip for chart ``desc``, degrading gracefully.

    Attempts, in order:
      1. an animated chart (``animate_chart``);
      2. a static render of the same specification, faded in from black;
      3. an AI-generated (or placeholder) illustration, faded in from black.

    Args:
        desc: Chart description taken from a chart tag.
        df: Source data.
        dur: Clip length in seconds.
        out: Output video path.

    Returns:
        The path of the written clip as a string.
    """
    chart_spec = None
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
        chart_spec = ChartGenerator(llm, df).generate_chart_spec(desc)
        return animate_chart(chart_spec, df, dur, out)
    except Exception as e:
        logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")

    # Reuse the specification obtained above rather than paying for a second
    # LLM round-trip. If none was obtained, skip straight to the illustration.
    frame = None
    if chart_spec is not None:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
            temp_png = Path(temp_png_file.name)
        try:
            if execute_chart_spec(chart_spec, df, temp_png):
                frame = cv2.imread(str(temp_png))  # None when the file cannot be decoded
        finally:
            # Remove the temp PNG whether or not the static render succeeded.
            temp_png.unlink(missing_ok=True)

    if frame is not None:
        return animate_image_fade(cv2.resize(frame, (WIDTH, HEIGHT)), dur, out)

    img = generate_image_from_prompt(f"A professional business chart showing {desc}")
    img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
    return animate_image_fade(img_cv, dur, out)
def concat_media(file_paths: List[str], output_path: Path):
    """Join media files into ``output_path`` using ffmpeg's concat demuxer.

    Inputs that do not exist or are 100 bytes or smaller are skipped. A single
    remaining input is simply copied. Otherwise the streams are concatenated
    with ``-c copy`` (no re-encoding).

    NOTE(review): stream copy presumably requires all inputs to share codec
    and stream parameters; confirm the clips produced upstream are compatible.

    Args:
        file_paths: Candidate input files, in playback order.
        output_path: Destination file. A sibling ``.txt`` list file is created
            next to it and removed afterwards.

    Raises:
        ValueError: If no input passes the validity filter.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    valid_paths = [p for p in file_paths if Path(p).exists() and Path(p).stat().st_size > 100]
    if not valid_paths:
        raise ValueError("No valid media files to concatenate.")
    if len(valid_paths) == 1:
        shutil.copy2(valid_paths[0], str(output_path))
        return

    list_file = output_path.with_suffix(".txt")
    try:
        with open(list_file, 'w', encoding="utf-8") as f:
            for path in valid_paths:
                # Inside the single-quoted "file" directive a literal quote has
                # to be written as '\'' or the list file is mis-parsed.
                escaped = str(Path(path).resolve()).replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(list_file), "-c", "copy", str(output_path)]
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    finally:
        # Covers failures while writing the list as well as ffmpeg errors.
        list_file.unlink(missing_ok=True)
# --- Main Business Logic Functions for Flask ---
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
    """Draft a Markdown report for an uploaded dataset and render its charts.

    The dataset is loaded, summarised and sent to the LLM, which returns
    Markdown containing chart tags. Up to ``MAX_CHARTS`` distinct tags are
    rendered to PNG and uploaded under
    ``sozo_projects/{uid}/{project_id}/charts/``.

    Args:
        buf: File path or file-like object with the dataset.
        name: Original file name (used to pick the reader).
        ctx: Free-text user context forwarded to the LLM.
        uid: Owner identifier used in the storage path.
        project_id: Project identifier used in the storage path and logs.
        bucket: Storage bucket object exposing ``blob(name)``.

    Returns:
        ``{"raw_md": <LLM Markdown>, "chartUrls": {description: public_url}}``.
        Charts that could not be rendered are absent from ``chartUrls``.
    """
    logging.info(f"Generating report draft for project {project_id}")
    df = load_dataframe_safely(buf, name)
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    enhanced_ctx = enhance_data_context(
        df,
        {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx},
    )
    # The prompt text is kept flush-left so the string sent to the model is unchanged.
    report_prompt = f"""
You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
**Instructions:**
1. **Executive Summary**: Start with a high-level summary of key findings.
2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
Valid chart types: bar, pie, line, scatter, hist.
Generate insights that would be valuable to C-level executives.
"""
    md = llm.invoke(report_prompt).content
    requested_charts = extract_chart_tags(md)[:MAX_CHARTS]
    chart_urls = {}
    chart_generator = ChartGenerator(llm, df)

    for desc in requested_charts:
        # Reserve a temp file name; the handle is closed before rendering.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            img_path = Path(temp_file.name)
        try:
            chart_spec = chart_generator.generate_chart_spec(desc)
            if not execute_chart_spec(chart_spec, df, img_path):
                continue  # rendering failed; the finally clause still removes the file
            blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(str(img_path))
            chart_urls[desc] = blob.public_url
            logging.info(f"Uploaded chart '{desc}' to {blob.public_url}")
        finally:
            os.unlink(img_path)

    return {"raw_md": md, "chartUrls": chart_urls}
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
    """Render one chart for ``description`` and upload it.

    Args:
        df: Source data.
        description: Free-text chart request.
        uid: Owner identifier used in the storage path.
        project_id: Project identifier used in the storage path and logs.
        bucket: Storage bucket object exposing ``blob(name)``.

    Returns:
        The uploaded chart's public URL, or ``None`` when rendering failed.
        The temporary PNG is removed in either case.
    """
    logging.info(f"Generating single chart '{description}' for project {project_id}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
    chart_generator = ChartGenerator(llm, df)

    # Reserve a temp file name; the handle is closed before rendering.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        img_path = Path(temp_file.name)

    try:
        chart_spec = chart_generator.generate_chart_spec(description)
        rendered = execute_chart_spec(chart_spec, df, img_path)
        if not rendered:
            return None
        blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(str(img_path))
        logging.info(f"Uploaded single chart to {blob.public_url}")
        return blob.public_url
    finally:
        os.unlink(img_path)
def _reserve_temp_path(suffix: str) -> Path:
    """Create an empty temp file with ``suffix`` and return its path (handle closed)."""
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return Path(handle.name)


def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
    """Turn a report into a narrated video and upload it.

    The LLM writes a script split by ``[SCENE_BREAK]``. For each scene the
    narration is synthesised (or replaced by 5 seconds of silence) and a clip
    of matching length is produced from the scene's first chart tag, or from a
    generated illustration when there is none. Clips and audio are
    concatenated, muxed with ffmpeg and uploaded to
    ``sozo_projects/{uid}/{project_id}/video.mp4``.

    Args:
        df: Source data for the charts.
        raw_md: Report Markdown the script is based on.
        uid: Owner identifier used in the storage path.
        project_id: Project identifier used in the storage path and logs.
        voice_model: Voice passed to ``deepgram_tts``.
        bucket: Storage bucket object exposing ``blob(name)``.

    Returns:
        The public URL of the uploaded video.

    Raises:
        ValueError: From ``concat_media`` when no usable clips were produced.
        subprocess.CalledProcessError: When an ffmpeg invocation fails.
    """
    logging.info(f"Generating video for project {project_id} with voice {voice_model}")
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.2)
    story_prompt = f"Based on the following report, create a script for a {VIDEO_SCENES}-scene video. Each scene must be separated by '[SCENE_BREAK]' and contain narration and one chart tag. Report: {raw_md}"
    script = llm.invoke(story_prompt).content
    scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]

    video_parts, audio_parts, temps = [], [], []
    try:
        for sc in scenes:
            descs, narrative = extract_chart_tags(sc), clean_narration(sc)

            # Audio: narration when TTS succeeds, otherwise 5 seconds of silence.
            audio_bytes = deepgram_tts(narrative, voice_model)
            mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
            temps.append(mp3)  # registered before writing so cleanup covers partial files
            if audio_bytes:
                mp3.write_bytes(audio_bytes)
                dur = audio_duration(str(mp3))
                if dur <= 0.1:
                    dur = 5.0
            else:
                dur = 5.0
                generate_silence_mp3(dur, mp3)
            audio_parts.append(str(mp3))

            # Video: chart clip for the first tag, or an illustration fade-in.
            mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
            temps.append(mp4)
            if descs:
                safe_chart(descs[0], df, dur, mp4)
            else:
                img = generate_image_from_prompt(narrative)
                img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
                animate_image_fade(img_cv, dur, mp4)
            video_parts.append(str(mp4))

        # Reserve the intermediate/final paths with their handles already
        # closed, so ffmpeg is the only writer of these files.
        silent_vid_path = _reserve_temp_path(".mp4")
        temps.append(silent_vid_path)
        audio_mix_path = _reserve_temp_path(".mp3")
        temps.append(audio_mix_path)
        final_vid_path = _reserve_temp_path(".mp4")
        temps.append(final_vid_path)

        concat_media(video_parts, silent_vid_path)
        concat_media(audio_parts, audio_mix_path)
        subprocess.run(
            ["ffmpeg", "-y", "-i", str(silent_vid_path), "-i", str(audio_mix_path),
             "-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac",
             "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(final_vid_path)],
            check=True, capture_output=True,
        )

        blob_name = f"sozo_projects/{uid}/{project_id}/video.mp4"
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(str(final_vid_path))
        logging.info(f"Uploaded video to {blob.public_url}")
        return blob.public_url
    finally:
        # Remove every temp artefact on success and on failure alike.
        for p in temps:
            if os.path.exists(p):
                os.unlink(p)