Spaces:
Sleeping
Sleeping
Update sozo_gen.py
Browse files- sozo_gen.py +215 -427
sozo_gen.py
CHANGED
|
@@ -13,6 +13,8 @@ import matplotlib
|
|
| 13 |
matplotlib.use("Agg")
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
from matplotlib.animation import FuncAnimation, FFMpegWriter
|
|
|
|
|
|
|
| 16 |
from PIL import Image
|
| 17 |
import cv2
|
| 18 |
import inspect
|
|
@@ -28,11 +30,14 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%
|
|
| 28 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 29 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 30 |
|
| 31 |
-
# ---
|
| 32 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 33 |
if not API_KEY:
|
| 34 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
# --- Helper Functions ---
|
| 37 |
def load_dataframe_safely(buf, name: str):
|
| 38 |
ext = Path(name).suffix.lower()
|
|
@@ -63,13 +68,17 @@ def audio_duration(path: str) -> float:
|
|
| 63 |
return float(res.stdout.strip())
|
| 64 |
except Exception: return 5.0
|
| 65 |
|
|
|
|
| 66 |
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
|
|
|
| 67 |
extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
|
|
|
|
|
|
|
| 68 |
|
| 69 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
|
| 70 |
def clean_narration(txt: str) -> str:
|
| 71 |
-
txt = TAG_RE.sub("", txt); txt = re_scene.sub("", txt)
|
| 72 |
-
phrases_to_remove = [r"chart tag", r"chart_tag", r"narration"]
|
| 73 |
for phrase in phrases_to_remove: txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
|
| 74 |
txt = re.sub(r"\s*\([^)]*\)", "", txt); txt = re.sub(r"[\*#_]", "", txt)
|
| 75 |
return re.sub(r"\s{2,}", " ", txt).strip()
|
|
@@ -89,10 +98,66 @@ def generate_image_from_prompt(prompt: str) -> Image.Image:
|
|
| 89 |
except Exception:
|
| 90 |
return placeholder_img()
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# --- Chart Generation System ---
|
|
|
|
| 93 |
class ChartSpecification:
|
| 94 |
-
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 95 |
-
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col
|
| 96 |
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
|
| 97 |
|
| 98 |
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
|
|
@@ -107,16 +172,22 @@ class ChartGenerator:
|
|
| 107 |
|
| 108 |
def generate_chart_spec(self, description: str) -> ChartSpecification:
|
| 109 |
safe_ctx = json_serializable(self.enhanced_ctx)
|
|
|
|
| 110 |
spec_prompt = f"""
|
| 111 |
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
|
| 112 |
**Dataset Info:** {json.dumps(safe_ctx, indent=2)}
|
| 113 |
**Chart Request:** {description}
|
| 114 |
**Return a JSON specification with these exact fields:**
|
| 115 |
{{
|
| 116 |
-
"chart_type": "bar|pie|line|scatter|hist",
|
| 117 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
}}
|
| 119 |
-
Return only the JSON specification, no additional text.
|
| 120 |
"""
|
| 121 |
try:
|
| 122 |
response = self.llm.invoke(spec_prompt).content.strip()
|
|
@@ -133,12 +204,13 @@ class ChartGenerator:
|
|
| 133 |
def _create_fallback_spec(self, description: str) -> ChartSpecification:
|
| 134 |
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
|
| 135 |
ctype = "bar"
|
| 136 |
-
for t in ["pie", "line", "scatter", "hist"]:
|
| 137 |
if t in description.lower(): ctype = t
|
| 138 |
x = categorical_cols[0] if categorical_cols else self.df.columns[0]
|
| 139 |
y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 140 |
return ChartSpecification(ctype, description, x, y)
|
| 141 |
|
|
|
|
| 142 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
| 143 |
try:
|
| 144 |
plot_data = prepare_plot_data(spec, df)
|
|
@@ -148,29 +220,47 @@ def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path:
|
|
| 148 |
elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 149 |
elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 150 |
elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
|
| 152 |
plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
|
| 153 |
return True
|
| 154 |
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
| 158 |
if spec.chart_type in ["bar", "pie"]:
|
| 159 |
if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
|
| 160 |
grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
|
| 161 |
return grouped.nlargest(spec.top_n or 10)
|
| 162 |
-
elif spec.chart_type
|
| 163 |
elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
|
|
|
|
|
|
|
|
|
|
| 164 |
elif spec.chart_type == "hist": return df[spec.x_col].dropna()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
return df[spec.x_col]
|
| 166 |
|
| 167 |
# --- Animation & Video Generation ---
|
|
|
|
| 168 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 169 |
plot_data = prepare_plot_data(spec, df)
|
| 170 |
frames = max(10, int(dur * fps))
|
| 171 |
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
|
| 172 |
plt.tight_layout(pad=3.0)
|
| 173 |
ctype = spec.chart_type
|
|
|
|
| 174 |
if ctype == "pie":
|
| 175 |
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 176 |
ax.set_title(spec.title); ax.axis('equal')
|
|
@@ -185,29 +275,79 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 185 |
for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
|
| 186 |
return bars
|
| 187 |
elif ctype == "scatter":
|
| 188 |
-
scat = ax.scatter([], [], alpha=0.7)
|
| 189 |
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
|
| 191 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
def update(i):
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
elif ctype == "hist":
|
| 197 |
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
|
| 198 |
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
|
| 199 |
def init(): [p.set_alpha(0) for p in patches]; return patches
|
| 200 |
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
|
| 204 |
x_full, y_full = plot_data.index, plot_data.values
|
| 205 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 206 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 207 |
-
def init():
|
|
|
|
|
|
|
|
|
|
| 208 |
def update(i):
|
| 209 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 210 |
-
line.set_data(x_full[:k], y_full[:k])
|
|
|
|
|
|
|
|
|
|
| 211 |
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
|
| 212 |
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
|
| 213 |
plt.close(fig)
|
|
@@ -258,9 +398,11 @@ def concat_media(file_paths: List[str], output_path: Path):
|
|
| 258 |
finally:
|
| 259 |
list_file.unlink(missing_ok=True)
|
| 260 |
|
| 261 |
-
# --- Main Business Logic Functions
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
# ADD THIS NEW HELPER FUNCTION SOMEWHERE NEAR THE TOP OF THE FILE
|
| 264 |
def sanitize_for_firebase_key(text: str) -> str:
|
| 265 |
"""Replaces Firebase-forbidden characters in a string with underscores."""
|
| 266 |
forbidden_chars = ['.', '$', '#', '[', ']', '/']
|
|
@@ -268,10 +410,6 @@ def sanitize_for_firebase_key(text: str) -> str:
|
|
| 268 |
text = text.replace(char, '_')
|
| 269 |
return text
|
| 270 |
|
| 271 |
-
# REPLACE THE OLD generate_report_draft WITH THIS CORRECTED VERSION
|
| 272 |
-
from scipy import stats
|
| 273 |
-
import re
|
| 274 |
-
|
| 275 |
def analyze_data_intelligence(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
|
| 276 |
"""
|
| 277 |
Autonomous data intelligence system that classifies domain,
|
|
@@ -483,7 +621,7 @@ def create_autonomous_prompt(df: pd.DataFrame, enhanced_ctx: Dict, intelligence:
|
|
| 483 |
|
| 484 |
**CHART INTEGRATION:**
|
| 485 |
Insert charts using: `<generate_chart: "chart_type | compelling description that advances the story">`
|
| 486 |
-
Available types: bar, pie, line, scatter, hist
|
| 487 |
|
| 488 |
Transform this data into a story that decision-makers can't stop reading."""
|
| 489 |
|
|
@@ -508,397 +646,38 @@ def generate_chart_strategy(intelligence: Dict) -> str:
|
|
| 508 |
|
| 509 |
# Add specific guidance based on data characteristics
|
| 510 |
if structure['is_timeseries']:
|
| 511 |
-
base_strategy += " Leverage time-series visualizations to show trends and patterns over time."
|
| 512 |
|
| 513 |
if 'correlations' in opportunities:
|
| 514 |
-
base_strategy += " Include correlation visualizations to reveal hidden relationships."
|
| 515 |
|
| 516 |
if 'segmentation' in opportunities:
|
| 517 |
base_strategy += " Use segmented charts to highlight different groups or categories."
|
| 518 |
|
| 519 |
return base_strategy
|
| 520 |
|
| 521 |
-
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
|
| 522 |
-
"""Enhanced context generation with AI-driven analysis"""
|
| 523 |
-
|
| 524 |
-
# Get autonomous intelligence analysis
|
| 525 |
-
intelligence = analyze_data_intelligence(df, ctx_dict)
|
| 526 |
-
|
| 527 |
-
# Original context enhancement
|
| 528 |
-
enhanced = ctx_dict.copy()
|
| 529 |
-
|
| 530 |
-
# Add statistical context
|
| 531 |
-
if not df.empty:
|
| 532 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 533 |
-
if len(numeric_cols) > 0:
|
| 534 |
-
key_metrics = {}
|
| 535 |
-
for col in numeric_cols[:3]: # Top 3 numeric columns
|
| 536 |
-
try:
|
| 537 |
-
mean_val = df[col].mean()
|
| 538 |
-
std_val = df[col].std()
|
| 539 |
-
key_metrics[col] = {
|
| 540 |
-
'mean': float(mean_val) if pd.notna(mean_val) else 0.0,
|
| 541 |
-
'std': float(std_val) if pd.notna(std_val) else 0.0
|
| 542 |
-
}
|
| 543 |
-
except:
|
| 544 |
-
key_metrics[col] = {'mean': 0.0, 'std': 0.0}
|
| 545 |
-
|
| 546 |
-
enhanced['statistical_summary'] = {
|
| 547 |
-
'numeric_columns': int(len(numeric_cols)),
|
| 548 |
-
'total_records': int(len(df)),
|
| 549 |
-
'missing_data_percentage': float((df.isnull().sum().sum() / (len(df) * len(df.columns))) * 100),
|
| 550 |
-
'key_metrics': key_metrics
|
| 551 |
-
}
|
| 552 |
-
|
| 553 |
-
# Add categorical context
|
| 554 |
-
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
|
| 555 |
-
if len(categorical_cols) > 0:
|
| 556 |
-
unique_values = {}
|
| 557 |
-
for col in categorical_cols[:3]:
|
| 558 |
-
try:
|
| 559 |
-
unique_values[col] = int(df[col].nunique())
|
| 560 |
-
except:
|
| 561 |
-
unique_values[col] = 0
|
| 562 |
-
|
| 563 |
-
enhanced['categorical_summary'] = {
|
| 564 |
-
'categorical_columns': int(len(categorical_cols)),
|
| 565 |
-
'unique_values': unique_values
|
| 566 |
-
}
|
| 567 |
-
|
| 568 |
-
# Merge with intelligence analysis
|
| 569 |
-
enhanced['ai_intelligence'] = intelligence
|
| 570 |
-
|
| 571 |
-
return enhanced
|
| 572 |
-
|
| 573 |
-
def create_chart_safe_context(enhanced_ctx: Dict) -> Dict:
|
| 574 |
-
"""
|
| 575 |
-
Create a chart-generator-safe version of enhanced context
|
| 576 |
-
by ensuring all values are JSON serializable
|
| 577 |
-
"""
|
| 578 |
-
def make_json_safe(obj):
|
| 579 |
-
if isinstance(obj, bool):
|
| 580 |
-
return bool(obj)
|
| 581 |
-
elif isinstance(obj, (np.integer, np.floating)):
|
| 582 |
-
return float(obj)
|
| 583 |
-
elif isinstance(obj, np.ndarray):
|
| 584 |
-
return obj.tolist()
|
| 585 |
-
elif isinstance(obj, np.bool_):
|
| 586 |
-
return bool(obj)
|
| 587 |
-
elif isinstance(obj, dict):
|
| 588 |
-
return {k: make_json_safe(v) for k, v in obj.items()}
|
| 589 |
-
elif isinstance(obj, (list, tuple)):
|
| 590 |
-
return [make_json_safe(item) for item in obj]
|
| 591 |
-
elif pd.isna(obj):
|
| 592 |
-
return None
|
| 593 |
-
elif hasattr(obj, 'item'): # numpy scalars
|
| 594 |
-
return obj.item()
|
| 595 |
-
else:
|
| 596 |
-
return obj
|
| 597 |
-
|
| 598 |
-
return make_json_safe(enhanced_ctx)
|
| 599 |
-
|
| 600 |
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
|
| 601 |
-
|
| 602 |
-
Enhanced autonomous report generation with intelligent narrative creation
|
| 603 |
-
"""
|
| 604 |
logging.info(f"Generating autonomous report draft for project {project_id}")
|
| 605 |
|
| 606 |
df = load_dataframe_safely(buf, name)
|
| 607 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
|
| 608 |
|
| 609 |
-
# Build enhanced context with AI intelligence
|
| 610 |
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
|
| 611 |
enhanced_ctx = enhance_data_context(df, ctx_dict)
|
| 612 |
-
|
| 613 |
-
# Get AI intelligence analysis
|
| 614 |
intelligence = analyze_data_intelligence(df, ctx_dict)
|
| 615 |
-
|
| 616 |
-
# Generate autonomous prompt
|
| 617 |
report_prompt = create_autonomous_prompt(df, enhanced_ctx, intelligence)
|
| 618 |
-
|
| 619 |
-
# Generate the report
|
| 620 |
md = llm.invoke(report_prompt).content
|
| 621 |
|
| 622 |
-
# Extract and process charts
|
| 623 |
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 624 |
chart_urls = {}
|
| 625 |
-
|
| 626 |
-
# Create a chart-safe context
|
| 627 |
-
chart_safe_ctx = create_chart_safe_context(enhanced_ctx)
|
| 628 |
-
|
| 629 |
-
# Try to pass the safe context to ChartGenerator
|
| 630 |
-
try:
|
| 631 |
-
chart_generator = ChartGenerator(llm, df, chart_safe_ctx)
|
| 632 |
-
except TypeError:
|
| 633 |
-
# Fallback: if ChartGenerator doesn't accept enhanced_ctx parameter
|
| 634 |
-
chart_generator = ChartGenerator(llm, df)
|
| 635 |
-
# If it has an enhanced_ctx attribute, set it safely
|
| 636 |
-
if hasattr(chart_generator, 'enhanced_ctx'):
|
| 637 |
-
chart_generator.enhanced_ctx = chart_safe_ctx
|
| 638 |
-
|
| 639 |
-
for desc in chart_descs:
|
| 640 |
-
# Create a safe key for Firebase
|
| 641 |
-
safe_desc = sanitize_for_firebase_key(desc)
|
| 642 |
-
|
| 643 |
-
# Replace the original description in the markdown with the safe one
|
| 644 |
-
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 645 |
-
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">') # Handle no quotes case
|
| 646 |
-
|
| 647 |
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 648 |
-
img_path = Path(temp_file.name)
|
| 649 |
-
try:
|
| 650 |
-
chart_spec = chart_generator.generate_chart_spec(desc) # Still generate spec from original desc
|
| 651 |
-
if execute_chart_spec(chart_spec, df, img_path):
|
| 652 |
-
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 653 |
-
blob = bucket.blob(blob_name)
|
| 654 |
-
blob.upload_from_filename(str(img_path))
|
| 655 |
-
|
| 656 |
-
# Use the safe key in the dictionary
|
| 657 |
-
chart_urls[safe_desc] = blob.public_url
|
| 658 |
-
logging.info(f"Uploaded chart '{desc}' to {blob.public_url} with safe key '{safe_desc}'")
|
| 659 |
-
finally:
|
| 660 |
-
if os.path.exists(img_path):
|
| 661 |
-
os.unlink(img_path)
|
| 662 |
-
|
| 663 |
-
return {"raw_md": md, "chartUrls": chart_urls}
|
| 664 |
-
|
| 665 |
-
# Additional helper functions for the autonomous system
|
| 666 |
-
|
| 667 |
-
def detect_data_relationships(df: pd.DataFrame) -> Dict[str, Any]:
|
| 668 |
-
"""Detect relationships and patterns in the data"""
|
| 669 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 670 |
-
relationships = {}
|
| 671 |
-
|
| 672 |
-
if len(numeric_cols) > 1:
|
| 673 |
-
corr_matrix = df[numeric_cols].corr()
|
| 674 |
-
# Find strong correlations (> 0.7 or < -0.7)
|
| 675 |
-
strong_correlations = []
|
| 676 |
-
for i in range(len(corr_matrix.columns)):
|
| 677 |
-
for j in range(i+1, len(corr_matrix.columns)):
|
| 678 |
-
corr_val = corr_matrix.iloc[i, j]
|
| 679 |
-
if abs(corr_val) > 0.7:
|
| 680 |
-
strong_correlations.append({
|
| 681 |
-
'var1': corr_matrix.columns[i],
|
| 682 |
-
'var2': corr_matrix.columns[j],
|
| 683 |
-
'correlation': corr_val
|
| 684 |
-
})
|
| 685 |
-
relationships['strong_correlations'] = strong_correlations
|
| 686 |
-
|
| 687 |
-
return relationships
|
| 688 |
-
|
| 689 |
-
def identify_key_metrics(df: pd.DataFrame, domain: str) -> List[str]:
|
| 690 |
-
"""Identify the most important metrics based on domain and data characteristics"""
|
| 691 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 692 |
-
|
| 693 |
-
domain_priorities = {
|
| 694 |
-
'financial': ['revenue', 'profit', 'cost', 'amount', 'price', 'margin'],
|
| 695 |
-
'survey': ['rating', 'score', 'satisfaction', 'response'],
|
| 696 |
-
'marketing': ['conversion', 'click', 'impression', 'engagement'],
|
| 697 |
-
'operational': ['efficiency', 'utilization', 'throughput', 'performance']
|
| 698 |
-
}
|
| 699 |
-
|
| 700 |
-
priorities = domain_priorities.get(domain, [])
|
| 701 |
-
key_metrics = []
|
| 702 |
-
|
| 703 |
-
# Match column names with domain priorities
|
| 704 |
-
for col in numeric_cols:
|
| 705 |
-
col_lower = col.lower()
|
| 706 |
-
for priority in priorities:
|
| 707 |
-
if priority in col_lower:
|
| 708 |
-
key_metrics.append(col)
|
| 709 |
-
break
|
| 710 |
-
|
| 711 |
-
# If no matches, use columns with highest variance (most interesting)
|
| 712 |
-
if not key_metrics and numeric_cols:
|
| 713 |
-
variances = df[numeric_cols].var().sort_values(ascending=False)
|
| 714 |
-
key_metrics = variances.head(3).index.tolist()
|
| 715 |
-
|
| 716 |
-
return key_metrics[:5] # Return top 5 key metrics
|
| 717 |
-
# Removed - no longer needed since we're letting AI decide everything organically
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
def generate_autonomous_charts(llm, df: pd.DataFrame, report_md: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
|
| 721 |
-
"""
|
| 722 |
-
Generates charts autonomously based on the report content and data characteristics.
|
| 723 |
-
"""
|
| 724 |
-
# Extract chart descriptions from the enhanced report
|
| 725 |
-
chart_descs = extract_chart_tags(report_md)[:MAX_CHARTS]
|
| 726 |
-
chart_urls = {}
|
| 727 |
-
|
| 728 |
-
if not chart_descs:
|
| 729 |
-
# If no charts specified, generate intelligent defaults
|
| 730 |
-
chart_descs = generate_intelligent_chart_suggestions(df, llm)
|
| 731 |
-
|
| 732 |
chart_generator = ChartGenerator(llm, df)
|
| 733 |
|
| 734 |
-
for desc in chart_descs:
|
| 735 |
-
try:
|
| 736 |
-
# Create a safe key for Firebase
|
| 737 |
-
safe_desc = sanitize_for_firebase_key(desc)
|
| 738 |
-
|
| 739 |
-
# Replace chart tags in markdown
|
| 740 |
-
report_md = report_md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 741 |
-
report_md = report_md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
|
| 742 |
-
|
| 743 |
-
# Generate chart
|
| 744 |
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 745 |
-
img_path = Path(temp_file.name)
|
| 746 |
-
try:
|
| 747 |
-
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 748 |
-
if execute_chart_spec(chart_spec, df, img_path):
|
| 749 |
-
blob_name = f"sozo_projects/{uid}/{project_id}/charts/{uuid.uuid4().hex}.png"
|
| 750 |
-
blob = bucket.blob(blob_name)
|
| 751 |
-
blob.upload_from_filename(str(img_path))
|
| 752 |
-
|
| 753 |
-
chart_urls[safe_desc] = blob.public_url
|
| 754 |
-
logging.info(f"Generated autonomous chart: {safe_desc}")
|
| 755 |
-
finally:
|
| 756 |
-
if os.path.exists(img_path):
|
| 757 |
-
os.unlink(img_path)
|
| 758 |
-
|
| 759 |
-
except Exception as e:
|
| 760 |
-
logging.error(f"Failed to generate chart '{desc}': {str(e)}")
|
| 761 |
-
continue
|
| 762 |
-
|
| 763 |
-
return chart_urls
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
def generate_intelligent_chart_suggestions(df: pd.DataFrame, llm) -> List[str]:
|
| 767 |
-
"""
|
| 768 |
-
Generates intelligent chart suggestions based on data characteristics.
|
| 769 |
-
"""
|
| 770 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 771 |
-
categorical_cols = df.select_dtypes(include=['object']).columns
|
| 772 |
-
|
| 773 |
-
suggestions = []
|
| 774 |
-
|
| 775 |
-
# Time series chart if temporal data exists
|
| 776 |
-
if detect_time_series(df):
|
| 777 |
-
suggestions.append("line | Time series trend analysis | Show temporal patterns")
|
| 778 |
-
|
| 779 |
-
# Distribution chart for numeric data
|
| 780 |
-
if len(numeric_cols) > 0:
|
| 781 |
-
main_numeric = numeric_cols[0]
|
| 782 |
-
suggestions.append(f"hist | Distribution of {main_numeric} | Understand data distribution")
|
| 783 |
-
|
| 784 |
-
# Correlation analysis if multiple numeric columns
|
| 785 |
-
if len(numeric_cols) > 1:
|
| 786 |
-
suggestions.append("scatter | Correlation analysis | Identify relationships between variables")
|
| 787 |
-
|
| 788 |
-
# Categorical breakdown
|
| 789 |
-
if len(categorical_cols) > 0:
|
| 790 |
-
main_categorical = categorical_cols[0]
|
| 791 |
-
suggestions.append(f"bar | {main_categorical} breakdown | Show categorical distribution")
|
| 792 |
-
|
| 793 |
-
return suggestions[:MAX_CHARTS]
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
# Helper functions (preserve existing functionality)
|
| 797 |
-
def detect_time_series(df: pd.DataFrame) -> bool:
|
| 798 |
-
"""Detect if dataset contains time series data."""
|
| 799 |
-
for col in df.columns:
|
| 800 |
-
if 'date' in col.lower() or 'time' in col.lower():
|
| 801 |
-
return True
|
| 802 |
-
try:
|
| 803 |
-
pd.to_datetime(df[col])
|
| 804 |
-
return True
|
| 805 |
-
except:
|
| 806 |
-
continue
|
| 807 |
-
return False
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
def detect_transactional_data(df: pd.DataFrame) -> bool:
|
| 811 |
-
"""Detect if dataset contains transactional data."""
|
| 812 |
-
transaction_indicators = ['transaction', 'payment', 'order', 'invoice', 'amount', 'quantity']
|
| 813 |
-
columns_lower = [col.lower() for col in df.columns]
|
| 814 |
-
return any(indicator in col for col in columns_lower for indicator in transaction_indicators)
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
def detect_experimental_data(df: pd.DataFrame) -> bool:
|
| 818 |
-
"""Detect if dataset contains experimental data."""
|
| 819 |
-
experimental_indicators = ['test', 'experiment', 'trial', 'group', 'treatment', 'control']
|
| 820 |
-
columns_lower = [col.lower() for col in df.columns]
|
| 821 |
-
return any(indicator in col for col in columns_lower for indicator in experimental_indicators)
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
def detect_temporal_frequency(date_series: pd.Series) -> str:
|
| 825 |
-
"""Detect the frequency of temporal data."""
|
| 826 |
-
if len(date_series) < 2:
|
| 827 |
-
return "insufficient_data"
|
| 828 |
-
|
| 829 |
-
# Calculate time differences
|
| 830 |
-
time_diffs = date_series.sort_values().diff().dropna()
|
| 831 |
-
median_diff = time_diffs.median()
|
| 832 |
-
|
| 833 |
-
if median_diff <= pd.Timedelta(days=1):
|
| 834 |
-
return "daily"
|
| 835 |
-
elif median_diff <= pd.Timedelta(days=7):
|
| 836 |
-
return "weekly"
|
| 837 |
-
elif median_diff <= pd.Timedelta(days=31):
|
| 838 |
-
return "monthly"
|
| 839 |
-
else:
|
| 840 |
-
return "irregular"
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
def determine_analysis_complexity(df: pd.DataFrame, domain_analysis: Dict[str, Any]) -> str:
|
| 844 |
-
"""Determine the complexity level of analysis required."""
|
| 845 |
-
complexity_factors = 0
|
| 846 |
-
|
| 847 |
-
# Data size factor
|
| 848 |
-
if len(df) > 10000:
|
| 849 |
-
complexity_factors += 1
|
| 850 |
-
if len(df.columns) > 20:
|
| 851 |
-
complexity_factors += 1
|
| 852 |
-
|
| 853 |
-
# Data type diversity
|
| 854 |
-
if len(df.select_dtypes(include=[np.number]).columns) > 5:
|
| 855 |
-
complexity_factors += 1
|
| 856 |
-
if len(df.select_dtypes(include=['object']).columns) > 5:
|
| 857 |
-
complexity_factors += 1
|
| 858 |
-
|
| 859 |
-
# Domain complexity
|
| 860 |
-
if domain_analysis["primary_domain"] in ["scientific", "financial"]:
|
| 861 |
-
complexity_factors += 1
|
| 862 |
-
|
| 863 |
-
if complexity_factors >= 3:
|
| 864 |
-
return "high"
|
| 865 |
-
elif complexity_factors >= 2:
|
| 866 |
-
return "medium"
|
| 867 |
-
else:
|
| 868 |
-
return "low"
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_id: str, bucket) -> Dict[str, str]:
|
| 872 |
-
"""
|
| 873 |
-
Fallback to original report generation logic if enhanced version fails.
|
| 874 |
-
"""
|
| 875 |
-
logging.info("Using fallback report generation")
|
| 876 |
-
|
| 877 |
-
# Original logic preserved
|
| 878 |
-
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
|
| 879 |
-
enhanced_ctx = enhance_data_context(df, ctx_dict)
|
| 880 |
-
|
| 881 |
-
report_prompt = f"""
|
| 882 |
-
You are a senior data analyst and business intelligence expert. Analyze the provided dataset and write a comprehensive executive-level Markdown report.
|
| 883 |
-
**Dataset Analysis Context:** {json.dumps(enhanced_ctx, indent=2)}
|
| 884 |
-
**Instructions:**
|
| 885 |
-
1. **Executive Summary**: Start with a high-level summary of key findings.
|
| 886 |
-
2. **Key Insights**: Provide 3-5 key insights, each with its own chart tag.
|
| 887 |
-
3. **Visual Support**: Insert chart tags like: `<generate_chart: "chart_type | specific description">`.
|
| 888 |
-
Valid chart types: bar, pie, line, scatter, hist.
|
| 889 |
-
Generate insights that would be valuable to C-level executives.
|
| 890 |
-
"""
|
| 891 |
-
|
| 892 |
-
md = llm.invoke(report_prompt).content
|
| 893 |
-
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 894 |
-
chart_urls = {}
|
| 895 |
-
chart_generator = ChartGenerator(llm, df)
|
| 896 |
-
|
| 897 |
for desc in chart_descs:
|
| 898 |
safe_desc = sanitize_for_firebase_key(desc)
|
| 899 |
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 900 |
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
|
| 901 |
-
|
| 902 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 903 |
img_path = Path(temp_file.name)
|
| 904 |
try:
|
|
@@ -908,41 +687,12 @@ def generate_original_report(df: pd.DataFrame, llm, ctx: str, uid: str, project_
|
|
| 908 |
blob = bucket.blob(blob_name)
|
| 909 |
blob.upload_from_filename(str(img_path))
|
| 910 |
chart_urls[safe_desc] = blob.public_url
|
|
|
|
| 911 |
finally:
|
| 912 |
if os.path.exists(img_path):
|
| 913 |
os.unlink(img_path)
|
| 914 |
-
|
| 915 |
-
return {"raw_md": md, "chartUrls": chart_urls}
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
def generate_fallback_report(autonomous_context: Dict[str, Any]) -> str:
|
| 919 |
-
"""
|
| 920 |
-
Generates a basic fallback report when enhanced generation fails.
|
| 921 |
-
"""
|
| 922 |
-
basic_info = autonomous_context["basic_info"]
|
| 923 |
-
domain = autonomous_context["domain"]["primary_domain"]
|
| 924 |
|
| 925 |
-
return
|
| 926 |
-
# What This Data Reveals
|
| 927 |
-
|
| 928 |
-
Looking at this {domain} dataset with {basic_info['shape'][0]} records, there are several key insights worth highlighting.
|
| 929 |
-
|
| 930 |
-
## The Numbers Tell a Story
|
| 931 |
-
|
| 932 |
-
This dataset contains {basic_info['shape'][1]} different variables, suggesting a comprehensive view of the underlying processes or behaviors being measured.
|
| 933 |
-
|
| 934 |
-
<generate_chart: "bar | Data overview showing key metrics">
|
| 935 |
-
|
| 936 |
-
## What You Should Know
|
| 937 |
-
|
| 938 |
-
The data structure and patterns suggest this is worth deeper investigation. The variety of data types and relationships indicate multiple analytical opportunities.
|
| 939 |
-
|
| 940 |
-
## Next Steps
|
| 941 |
-
|
| 942 |
-
Based on this initial analysis, I recommend diving deeper into the specific patterns and relationships within the data to unlock more actionable insights.
|
| 943 |
-
|
| 944 |
-
*Note: This is a simplified analysis. Enhanced storytelling temporarily unavailable.*
|
| 945 |
-
"""
|
| 946 |
|
| 947 |
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
|
| 948 |
logging.info(f"Generating single chart '{description}' for project {project_id}")
|
|
@@ -963,15 +713,36 @@ def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_
|
|
| 963 |
os.unlink(img_path)
|
| 964 |
return None
|
| 965 |
|
|
|
|
| 966 |
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
|
| 967 |
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
|
| 968 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
|
| 969 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 970 |
script = llm.invoke(story_prompt).content
|
| 971 |
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
|
| 972 |
video_parts, audio_parts, temps = [], [], []
|
| 973 |
-
|
| 974 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
audio_bytes = deepgram_tts(narrative, voice_model)
|
| 976 |
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
|
| 977 |
if audio_bytes:
|
|
@@ -980,13 +751,30 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project
|
|
| 980 |
else:
|
| 981 |
dur = 5.0; generate_silence_mp3(dur, mp3)
|
| 982 |
audio_parts.append(str(mp3)); temps.append(mp3)
|
|
|
|
| 983 |
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 984 |
-
|
| 985 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 986 |
img = generate_image_from_prompt(narrative)
|
| 987 |
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
|
| 988 |
animate_image_fade(img_cv, dur, mp4)
|
| 989 |
-
|
| 990 |
|
| 991 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
|
| 992 |
tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
|
|
|
|
| 13 |
matplotlib.use("Agg")
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
from matplotlib.animation import FuncAnimation, FFMpegWriter
|
| 16 |
+
import seaborn as sns # Added for heatmaps
|
| 17 |
+
from scipy import stats # Added for scatterplot regression
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
| 20 |
import inspect
|
|
|
|
| 30 |
FPS, WIDTH, HEIGHT = 24, 1280, 720
|
| 31 |
MAX_CHARTS, VIDEO_SCENES = 5, 5
|
| 32 |
|
| 33 |
+
# --- API Initialization ---
|
| 34 |
API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 35 |
if not API_KEY:
|
| 36 |
raise ValueError("GOOGLE_API_KEY environment variable not set.")
|
| 37 |
|
| 38 |
+
# NEW: Pexels API Key
|
| 39 |
+
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
|
| 40 |
+
|
| 41 |
# --- Helper Functions ---
|
| 42 |
def load_dataframe_safely(buf, name: str):
|
| 43 |
ext = Path(name).suffix.lower()
|
|
|
|
| 68 |
return float(res.stdout.strip())
|
| 69 |
except Exception: return 5.0
|
| 70 |
|
| 71 |
+
# UPDATED: Regex for chart tags and NEW regex for stock video tags
|
| 72 |
TAG_RE = re.compile( r'[<[]\s*generate_?chart\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 73 |
+
TAG_RE_PEXELS = re.compile( r'[<[]\s*generate_?stock_?video\s*[:=]?\s*[\"\'“”]?(?P<d>[^>\"\'”\]]+?)[\"\'“”]?\s*[>\]]', re.I, )
|
| 74 |
extract_chart_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE.finditer(t or "")) )
|
| 75 |
+
extract_pexels_tags = lambda t: list( dict.fromkeys(m.group("d").strip() for m in TAG_RE_PEXELS.finditer(t or "")) )
|
| 76 |
+
|
| 77 |
|
| 78 |
re_scene = re.compile(r"^\s*scene\s*\d+[:.\- ]*", re.I | re.M)
|
| 79 |
def clean_narration(txt: str) -> str:
|
| 80 |
+
txt = TAG_RE.sub("", txt); txt = TAG_RE_PEXELS.sub("", txt); txt = re_scene.sub("", txt)
|
| 81 |
+
phrases_to_remove = [r"chart tag", r"chart_tag", r"narration", r"stock video tag"]
|
| 82 |
for phrase in phrases_to_remove: txt = re.sub(phrase, "", txt, flags=re.IGNORECASE)
|
| 83 |
txt = re.sub(r"\s*\([^)]*\)", "", txt); txt = re.sub(r"[\*#_]", "", txt)
|
| 84 |
return re.sub(r"\s{2,}", " ", txt).strip()
|
|
|
|
| 98 |
except Exception:
|
| 99 |
return placeholder_img()
|
| 100 |
|
| 101 |
+
# NEW: Pexels video search and download function
|
| 102 |
+
def search_and_download_pexels_video(query: str, duration: float, out_path: Path) -> str:
|
| 103 |
+
if not PEXELS_API_KEY:
|
| 104 |
+
logging.warning("PEXELS_API_KEY not set. Cannot fetch stock video.")
|
| 105 |
+
return None
|
| 106 |
+
try:
|
| 107 |
+
headers = {"Authorization": PEXELS_API_KEY}
|
| 108 |
+
params = {"query": query, "per_page": 15, "orientation": "landscape"}
|
| 109 |
+
response = requests.get("https://api.pexels.com/videos/search", headers=headers, params=params, timeout=20)
|
| 110 |
+
response.raise_for_status()
|
| 111 |
+
videos = response.json().get('videos', [])
|
| 112 |
+
if not videos:
|
| 113 |
+
logging.warning(f"No Pexels videos found for query: '{query}'")
|
| 114 |
+
return None
|
| 115 |
+
|
| 116 |
+
# Find a suitable video file (prefer HD)
|
| 117 |
+
video_to_download = None
|
| 118 |
+
for video in videos:
|
| 119 |
+
for f in video.get('video_files', []):
|
| 120 |
+
if f.get('quality') == 'hd' and f.get('width') >= 1280:
|
| 121 |
+
video_to_download = f['link']
|
| 122 |
+
break
|
| 123 |
+
if video_to_download:
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
if not video_to_download:
|
| 127 |
+
logging.warning(f"No suitable HD video file found for query: '{query}'")
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
# Download to a temporary file
|
| 131 |
+
with requests.get(video_to_download, stream=True, timeout=60) as r:
|
| 132 |
+
r.raise_for_status()
|
| 133 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_dl_file:
|
| 134 |
+
for chunk in r.iter_content(chunk_size=8192):
|
| 135 |
+
temp_dl_file.write(chunk)
|
| 136 |
+
temp_dl_path = Path(temp_dl_file.name)
|
| 137 |
+
|
| 138 |
+
# Use FFmpeg to resize, crop, and trim to exact duration
|
| 139 |
+
cmd = [
|
| 140 |
+
"ffmpeg", "-y", "-i", str(temp_dl_path),
|
| 141 |
+
"-vf", f"scale={WIDTH}:{HEIGHT}:force_original_aspect_ratio=decrease,pad={WIDTH}:{HEIGHT}:(ow-iw)/2:(oh-ih)/2,setsar=1",
|
| 142 |
+
"-t", f"{duration:.3f}",
|
| 143 |
+
"-c:v", "libx264", "-pix_fmt", "yuv420p", "-an",
|
| 144 |
+
str(out_path)
|
| 145 |
+
]
|
| 146 |
+
subprocess.run(cmd, check=True, capture_output=True)
|
| 147 |
+
temp_dl_path.unlink()
|
| 148 |
+
return str(out_path)
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logging.error(f"Pexels video processing failed for query '{query}': {e}")
|
| 152 |
+
if 'temp_dl_path' in locals() and temp_dl_path.exists():
|
| 153 |
+
temp_dl_path.unlink()
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
# --- Chart Generation System ---
|
| 157 |
+
# UPDATED: ChartSpecification to include size_col for bubble charts
|
| 158 |
class ChartSpecification:
|
| 159 |
+
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str = None, size_col: str = None, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
| 160 |
+
self.chart_type = chart_type; self.title = title; self.x_col = x_col; self.y_col = y_col; self.size_col = size_col
|
| 161 |
self.agg_method = agg_method or "sum"; self.filter_condition = filter_condition; self.top_n = top_n; self.color_scheme = color_scheme
|
| 162 |
|
| 163 |
def enhance_data_context(df: pd.DataFrame, ctx_dict: Dict) -> Dict:
|
|
|
|
| 172 |
|
| 173 |
def generate_chart_spec(self, description: str) -> ChartSpecification:
|
| 174 |
safe_ctx = json_serializable(self.enhanced_ctx)
|
| 175 |
+
# UPDATED: Prompt to include new chart types
|
| 176 |
spec_prompt = f"""
|
| 177 |
You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
|
| 178 |
**Dataset Info:** {json.dumps(safe_ctx, indent=2)}
|
| 179 |
**Chart Request:** {description}
|
| 180 |
**Return a JSON specification with these exact fields:**
|
| 181 |
{{
|
| 182 |
+
"chart_type": "bar|pie|line|scatter|hist|heatmap|area|bubble",
|
| 183 |
+
"title": "Professional chart title",
|
| 184 |
+
"x_col": "column_name_for_x_axis_or_null_for_heatmap",
|
| 185 |
+
"y_col": "column_name_for_y_axis_or_null",
|
| 186 |
+
"size_col": "column_name_for_bubble_size_or_null",
|
| 187 |
+
"agg_method": "sum|mean|count|max|min|null",
|
| 188 |
+
"top_n": "number_for_top_n_filtering_or_null"
|
| 189 |
}}
|
| 190 |
+
Return only the JSON specification, no additional text. For heatmaps, x_col and y_col can be null if it's a correlation matrix of all numeric columns.
|
| 191 |
"""
|
| 192 |
try:
|
| 193 |
response = self.llm.invoke(spec_prompt).content.strip()
|
|
|
|
| 204 |
def _create_fallback_spec(self, description: str) -> ChartSpecification:
|
| 205 |
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
|
| 206 |
ctype = "bar"
|
| 207 |
+
for t in ["pie", "line", "scatter", "hist", "heatmap", "area", "bubble"]:
|
| 208 |
if t in description.lower(): ctype = t
|
| 209 |
x = categorical_cols[0] if categorical_cols else self.df.columns[0]
|
| 210 |
y = numeric_cols[0] if numeric_cols and len(self.df.columns) > 1 else (self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 211 |
return ChartSpecification(ctype, description, x, y)
|
| 212 |
|
| 213 |
+
# UPDATED: execute_chart_spec to include new chart types
|
| 214 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
| 215 |
try:
|
| 216 |
plot_data = prepare_plot_data(spec, df)
|
|
|
|
| 220 |
elif spec.chart_type == "line": ax.plot(plot_data.index, plot_data.values, marker='o', linewidth=2, color='#A23B72'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 221 |
elif spec.chart_type == "scatter": ax.scatter(plot_data.iloc[:, 0], plot_data.iloc[:, 1], alpha=0.6, color='#F18F01'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 222 |
elif spec.chart_type == "hist": ax.hist(plot_data.values, bins=20, color='#C73E1D', alpha=0.7, edgecolor='black'); ax.set_xlabel(spec.x_col); ax.set_ylabel('Frequency'); ax.grid(True, alpha=0.3)
|
| 223 |
+
elif spec.chart_type == "area": ax.fill_between(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.4); ax.plot(plot_data.index, plot_data.values, color="#4E79A7", alpha=0.8); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 224 |
+
elif spec.chart_type == "heatmap": sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax); plt.xticks(rotation=45, ha="right"); plt.yticks(rotation=0)
|
| 225 |
+
elif spec.chart_type == "bubble":
|
| 226 |
+
sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
|
| 227 |
+
ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0.6, color='#59A14F'); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col); ax.grid(True, alpha=0.3)
|
| 228 |
+
|
| 229 |
ax.set_title(spec.title, fontsize=14, fontweight='bold', pad=20); plt.tight_layout()
|
| 230 |
plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white'); plt.close()
|
| 231 |
return True
|
| 232 |
except Exception as e: logging.error(f"Static chart generation failed for '{spec.title}': {e}"); return False
|
| 233 |
|
| 234 |
+
# UPDATED: prepare_plot_data to handle new chart types
|
| 235 |
+
def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame):
|
| 236 |
+
if spec.chart_type not in ["heatmap"]:
|
| 237 |
+
if spec.x_col not in df.columns or (spec.y_col and spec.y_col not in df.columns): raise ValueError(f"Invalid columns in chart spec: {spec.x_col}, {spec.y_col}")
|
| 238 |
+
|
| 239 |
if spec.chart_type in ["bar", "pie"]:
|
| 240 |
if not spec.y_col: return df[spec.x_col].value_counts().nlargest(spec.top_n or 10)
|
| 241 |
grouped = df.groupby(spec.x_col)[spec.y_col].agg(spec.agg_method or 'sum')
|
| 242 |
return grouped.nlargest(spec.top_n or 10)
|
| 243 |
+
elif spec.chart_type in ["line", "area"]: return df.set_index(spec.x_col)[spec.y_col].sort_index()
|
| 244 |
elif spec.chart_type == "scatter": return df[[spec.x_col, spec.y_col]].dropna()
|
| 245 |
+
elif spec.chart_type == "bubble":
|
| 246 |
+
if not spec.size_col or spec.size_col not in df.columns: raise ValueError("Bubble chart requires a valid size_col.")
|
| 247 |
+
return df[[spec.x_col, spec.y_col, spec.size_col]].dropna()
|
| 248 |
elif spec.chart_type == "hist": return df[spec.x_col].dropna()
|
| 249 |
+
elif spec.chart_type == "heatmap":
|
| 250 |
+
numeric_cols = df.select_dtypes(include=np.number).columns
|
| 251 |
+
if not numeric_cols.any(): raise ValueError("Heatmap requires numeric columns.")
|
| 252 |
+
return df[numeric_cols].corr()
|
| 253 |
return df[spec.x_col]
|
| 254 |
|
| 255 |
# --- Animation & Video Generation ---
|
| 256 |
+
# UPDATED: animate_chart with enhanced animations and new chart types
|
| 257 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 258 |
plot_data = prepare_plot_data(spec, df)
|
| 259 |
frames = max(10, int(dur * fps))
|
| 260 |
fig, ax = plt.subplots(figsize=(WIDTH / 100, HEIGHT / 100), dpi=100)
|
| 261 |
plt.tight_layout(pad=3.0)
|
| 262 |
ctype = spec.chart_type
|
| 263 |
+
|
| 264 |
if ctype == "pie":
|
| 265 |
wedges, _, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 266 |
ax.set_title(spec.title); ax.axis('equal')
|
|
|
|
| 275 |
for b, h in zip(bars, plot_data.values): b.set_height(h * (i / (frames - 1)))
|
| 276 |
return bars
|
| 277 |
elif ctype == "scatter":
|
|
|
|
| 278 |
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
|
| 279 |
+
# Calculate regression line
|
| 280 |
+
slope, intercept, _, _, _ = stats.linregress(x_full, y_full)
|
| 281 |
+
reg_line_x = np.array([x_full.min(), x_full.max()])
|
| 282 |
+
reg_line_y = slope * reg_line_x + intercept
|
| 283 |
+
|
| 284 |
+
scat = ax.scatter([], [], alpha=0.7, color='#F18F01')
|
| 285 |
+
line, = ax.plot([], [], 'r--', lw=2) # Regression line
|
| 286 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min(), y_full.max())
|
| 287 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 288 |
+
|
| 289 |
+
def init():
|
| 290 |
+
scat.set_offsets(np.empty((0, 2)))
|
| 291 |
+
line.set_data([], [])
|
| 292 |
+
return [scat, line]
|
| 293 |
def update(i):
|
| 294 |
+
# Animate points for the first 70% of frames
|
| 295 |
+
point_frames = int(frames * 0.7)
|
| 296 |
+
if i <= point_frames:
|
| 297 |
+
k = max(1, int(len(x_full) * (i / point_frames)))
|
| 298 |
+
scat.set_offsets(plot_data.iloc[:k].values)
|
| 299 |
+
# Animate regression line for the last 30%
|
| 300 |
+
else:
|
| 301 |
+
line_frame = i - point_frames
|
| 302 |
+
line_total_frames = frames - point_frames
|
| 303 |
+
current_x = reg_line_x[0] + (reg_line_x[1] - reg_line_x[0]) * (line_frame / line_total_frames)
|
| 304 |
+
line.set_data([reg_line_x[0], current_x], [reg_line_y[0], slope * current_x + intercept])
|
| 305 |
+
return [scat, line]
|
| 306 |
elif ctype == "hist":
|
| 307 |
_, _, patches = ax.hist(plot_data, bins=20, alpha=0)
|
| 308 |
ax.set_title(spec.title); ax.set_xlabel(spec.x_col); ax.set_ylabel("Frequency")
|
| 309 |
def init(): [p.set_alpha(0) for p in patches]; return patches
|
| 310 |
def update(i): [p.set_alpha((i / (frames - 1)) * 0.7) for p in patches]; return patches
|
| 311 |
+
elif ctype == "area":
|
| 312 |
+
plot_data = plot_data.sort_index()
|
| 313 |
+
x_full, y_full = plot_data.index, plot_data.values
|
| 314 |
+
fill = ax.fill_between(x_full, np.zeros_like(y_full), color="#4E79A7", alpha=0.4)
|
| 315 |
+
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(0, y_full.max() * 1.1)
|
| 316 |
+
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 317 |
+
def init(): return [fill]
|
| 318 |
+
def update(i):
|
| 319 |
+
ax.collections.clear()
|
| 320 |
+
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 321 |
+
fill = ax.fill_between(x_full[:k], y_full[:k], color="#4E79A7", alpha=0.4)
|
| 322 |
+
return [fill]
|
| 323 |
+
elif ctype == "heatmap":
|
| 324 |
+
sns.heatmap(plot_data, annot=True, cmap="viridis", ax=ax, alpha=0)
|
| 325 |
+
ax.set_title(spec.title)
|
| 326 |
+
def init(): ax.collections[0].set_alpha(0); return [ax.collections[0]]
|
| 327 |
+
def update(i): ax.collections[0].set_alpha(i / (frames - 1)); return [ax.collections[0]]
|
| 328 |
+
elif ctype == "bubble":
|
| 329 |
+
sizes = (plot_data[spec.size_col] - plot_data[spec.size_col].min() + 1) / (plot_data[spec.size_col].max() - plot_data[spec.size_col].min() + 1) * 2000 + 50
|
| 330 |
+
scat = ax.scatter(plot_data[spec.x_col], plot_data[spec.y_col], s=sizes, alpha=0, color='#59A14F')
|
| 331 |
+
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 332 |
+
def init(): scat.set_alpha(0); return [scat]
|
| 333 |
+
def update(i): scat.set_alpha(i / (frames - 1) * 0.7); return [scat]
|
| 334 |
+
else: # line (Time Series)
|
| 335 |
+
line, = ax.plot([], [], lw=2, color='#A23B72')
|
| 336 |
+
markers, = ax.plot([], [], 'o', color='#A23B72', markersize=5) # Animated markers
|
| 337 |
plot_data = plot_data.sort_index() if not plot_data.index.is_monotonic_increasing else plot_data
|
| 338 |
x_full, y_full = plot_data.index, plot_data.values
|
| 339 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 340 |
ax.set_title(spec.title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 341 |
+
def init():
|
| 342 |
+
line.set_data([], [])
|
| 343 |
+
markers.set_data([], [])
|
| 344 |
+
return [line, markers]
|
| 345 |
def update(i):
|
| 346 |
k = max(2, int(len(x_full) * (i / (frames - 1))))
|
| 347 |
+
line.set_data(x_full[:k], y_full[:k])
|
| 348 |
+
markers.set_data(x_full[:k], y_full[:k])
|
| 349 |
+
return [line, markers]
|
| 350 |
+
|
| 351 |
anim = FuncAnimation(fig, update, init_func=init, frames=frames, blit=True, interval=1000 / fps)
|
| 352 |
anim.save(str(out), writer=FFMpegWriter(fps=fps), dpi=144)
|
| 353 |
plt.close(fig)
|
|
|
|
| 398 |
finally:
|
| 399 |
list_file.unlink(missing_ok=True)
|
| 400 |
|
| 401 |
+
# --- Main Business Logic Functions ---
|
| 402 |
+
# This section containing generate_report_draft and its helpers is left unchanged as requested.
|
| 403 |
+
# ... (all functions from sanitize_for_firebase_key to generate_single_chart) ...
|
| 404 |
+
# The following functions are preserved exactly as they were in the original code provided.
|
| 405 |
|
|
|
|
| 406 |
def sanitize_for_firebase_key(text: str) -> str:
|
| 407 |
"""Replaces Firebase-forbidden characters in a string with underscores."""
|
| 408 |
forbidden_chars = ['.', '$', '#', '[', ']', '/']
|
|
|
|
| 410 |
text = text.replace(char, '_')
|
| 411 |
return text
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
def analyze_data_intelligence(df: pd.DataFrame, ctx_dict: Dict) -> Dict[str, Any]:
|
| 414 |
"""
|
| 415 |
Autonomous data intelligence system that classifies domain,
|
|
|
|
| 621 |
|
| 622 |
**CHART INTEGRATION:**
|
| 623 |
Insert charts using: `<generate_chart: "chart_type | compelling description that advances the story">`
|
| 624 |
+
Available types: bar, pie, line, scatter, hist, heatmap, area, bubble
|
| 625 |
|
| 626 |
Transform this data into a story that decision-makers can't stop reading."""
|
| 627 |
|
|
|
|
| 646 |
|
| 647 |
# Add specific guidance based on data characteristics
|
| 648 |
if structure['is_timeseries']:
|
| 649 |
+
base_strategy += " Leverage time-series visualizations like line or area charts to show trends and patterns over time."
|
| 650 |
|
| 651 |
if 'correlations' in opportunities:
|
| 652 |
+
base_strategy += " Include correlation visualizations like scatterplots or heatmaps to reveal hidden relationships."
|
| 653 |
|
| 654 |
if 'segmentation' in opportunities:
|
| 655 |
base_strategy += " Use segmented charts to highlight different groups or categories."
|
| 656 |
|
| 657 |
return base_strategy
|
| 658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
def generate_report_draft(buf, name: str, ctx: str, uid: str, project_id: str, bucket):
|
| 660 |
+
# This function remains unchanged as per the instructions.
|
|
|
|
|
|
|
| 661 |
logging.info(f"Generating autonomous report draft for project {project_id}")
|
| 662 |
|
| 663 |
df = load_dataframe_safely(buf, name)
|
| 664 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.1)
|
| 665 |
|
|
|
|
| 666 |
ctx_dict = {"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx}
|
| 667 |
enhanced_ctx = enhance_data_context(df, ctx_dict)
|
|
|
|
|
|
|
| 668 |
intelligence = analyze_data_intelligence(df, ctx_dict)
|
|
|
|
|
|
|
| 669 |
report_prompt = create_autonomous_prompt(df, enhanced_ctx, intelligence)
|
|
|
|
|
|
|
| 670 |
md = llm.invoke(report_prompt).content
|
| 671 |
|
|
|
|
| 672 |
chart_descs = extract_chart_tags(md)[:MAX_CHARTS]
|
| 673 |
chart_urls = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
chart_generator = ChartGenerator(llm, df)
|
| 675 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
for desc in chart_descs:
|
| 677 |
safe_desc = sanitize_for_firebase_key(desc)
|
| 678 |
md = md.replace(f'<generate_chart: "{desc}">', f'<generate_chart: "{safe_desc}">')
|
| 679 |
md = md.replace(f'<generate_chart: {desc}>', f'<generate_chart: "{safe_desc}">')
|
| 680 |
+
|
| 681 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
| 682 |
img_path = Path(temp_file.name)
|
| 683 |
try:
|
|
|
|
| 687 |
blob = bucket.blob(blob_name)
|
| 688 |
blob.upload_from_filename(str(img_path))
|
| 689 |
chart_urls[safe_desc] = blob.public_url
|
| 690 |
+
logging.info(f"Uploaded chart '{desc}' to {blob.public_url} with safe key '{safe_desc}'")
|
| 691 |
finally:
|
| 692 |
if os.path.exists(img_path):
|
| 693 |
os.unlink(img_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
|
| 695 |
+
return {"raw_md": md, "chartUrls": chart_urls}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
|
| 697 |
def generate_single_chart(df: pd.DataFrame, description: str, uid: str, project_id: str, bucket):
|
| 698 |
logging.info(f"Generating single chart '{description}' for project {project_id}")
|
|
|
|
| 713 |
os.unlink(img_path)
|
| 714 |
return None
|
| 715 |
|
| 716 |
+
# UPDATED: generate_video_from_project to handle Pexels integration
|
| 717 |
def generate_video_from_project(df: pd.DataFrame, raw_md: str, uid: str, project_id: str, voice_model: str, bucket):
|
| 718 |
logging.info(f"Generating video for project {project_id} with voice {voice_model}")
|
| 719 |
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY, temperature=0.2)
|
| 720 |
+
|
| 721 |
+
# UPDATED: Prompt to create Intro/Conclusion scenes with stock video tags
|
| 722 |
+
story_prompt = f"""
|
| 723 |
+
Based on the following report, create a script for a {VIDEO_SCENES}-scene video.
|
| 724 |
+
1. The first scene MUST be an "Introduction". It must contain narration and a stock video tag like: <generate_stock_video: "search query">.
|
| 725 |
+
2. The last scene MUST be a "Conclusion". It must also contain narration and a stock video tag.
|
| 726 |
+
3. The middle scenes should each contain narration and one chart tag from the report.
|
| 727 |
+
4. Separate each scene with '[SCENE_BREAK]'.
|
| 728 |
+
|
| 729 |
+
Report: {raw_md}
|
| 730 |
+
|
| 731 |
+
Only output the script, no extra text.
|
| 732 |
+
"""
|
| 733 |
script = llm.invoke(story_prompt).content
|
| 734 |
scenes = [s.strip() for s in script.split("[SCENE_BREAK]") if s.strip()]
|
| 735 |
video_parts, audio_parts, temps = [], [], []
|
| 736 |
+
|
| 737 |
+
for i, sc in enumerate(scenes):
|
| 738 |
+
chart_descs = extract_chart_tags(sc)
|
| 739 |
+
pexels_descs = extract_pexels_tags(sc)
|
| 740 |
+
narrative = clean_narration(sc)
|
| 741 |
+
|
| 742 |
+
if not narrative:
|
| 743 |
+
logging.warning(f"Scene {i+1} has no narration, skipping.")
|
| 744 |
+
continue
|
| 745 |
+
|
| 746 |
audio_bytes = deepgram_tts(narrative, voice_model)
|
| 747 |
mp3 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp3"
|
| 748 |
if audio_bytes:
|
|
|
|
| 751 |
else:
|
| 752 |
dur = 5.0; generate_silence_mp3(dur, mp3)
|
| 753 |
audio_parts.append(str(mp3)); temps.append(mp3)
|
| 754 |
+
|
| 755 |
mp4 = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.mp4"
|
| 756 |
+
video_generated = False
|
| 757 |
+
|
| 758 |
+
if pexels_descs:
|
| 759 |
+
logging.info(f"Scene {i+1}: Found Pexels tag '{pexels_descs[0]}'. Searching for video.")
|
| 760 |
+
video_path = search_and_download_pexels_video(pexels_descs[0], dur, mp4)
|
| 761 |
+
if video_path:
|
| 762 |
+
video_parts.append(video_path)
|
| 763 |
+
temps.append(Path(video_path))
|
| 764 |
+
video_generated = True
|
| 765 |
+
|
| 766 |
+
if not video_generated and chart_descs:
|
| 767 |
+
logging.info(f"Scene {i+1}: Found chart tag '{chart_descs[0]}'. Generating chart animation.")
|
| 768 |
+
safe_chart(chart_descs[0], df, dur, mp4)
|
| 769 |
+
video_parts.append(str(mp4)); temps.append(mp4)
|
| 770 |
+
video_generated = True
|
| 771 |
+
|
| 772 |
+
if not video_generated:
|
| 773 |
+
logging.warning(f"Scene {i+1}: No valid chart or stock video tag found. Using fallback image.")
|
| 774 |
img = generate_image_from_prompt(narrative)
|
| 775 |
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
|
| 776 |
animate_image_fade(img_cv, dur, mp4)
|
| 777 |
+
video_parts.append(str(mp4)); temps.append(mp4)
|
| 778 |
|
| 779 |
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_vid, \
|
| 780 |
tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_aud, \
|