Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
##############################################################################
|
| 2 |
# Sozo Business Studio · 10-Jul-2025
|
| 3 |
-
# • FIXED: Animation and FFmpeg errors
|
| 4 |
# • FIXED: The 'can't multiply sequence' error by replacing the animation engine.
|
| 5 |
# • FIXED: FFmpeg failures with a robust media concatenation function.
|
| 6 |
# • NOTE: The user's prompts, classes, and AI calls are preserved exactly.
|
|
@@ -126,19 +126,6 @@ def build_pdf(md: str, charts: Dict[str, str]) -> bytes:
|
|
| 126 |
pdf.set_font("Arial", "", 11); pdf.write_html(html)
|
| 127 |
return pdf.output(dest="S")
|
| 128 |
|
| 129 |
-
def quick_chart(desc: str, df: pd.DataFrame, out: Path):
|
| 130 |
-
ctype, *rest = [s.strip().lower() for s in desc.split("|", 1)]; ctype = ctype or "bar"
|
| 131 |
-
title = rest[0] if rest else desc
|
| 132 |
-
num_cols = df.select_dtypes("number").columns; cat_cols = df.select_dtypes(exclude="number").columns
|
| 133 |
-
with plt.ioff():
|
| 134 |
-
fig, ax = plt.subplots(figsize=(6, 3.4), dpi=150)
|
| 135 |
-
if ctype == "pie" and len(cat_cols) >= 1 and len(num_cols) >= 1: ax.pie(df.groupby(cat_cols[0])[num_cols[0]].sum().head(8), labels=df.groupby(cat_cols[0])[num_cols[0]].sum().head(8).index, autopct="%1.1f%%", startangle=90)
|
| 136 |
-
elif ctype == "line" and len(num_cols) >= 1: df[num_cols[0]].plot(kind="line", ax=ax)
|
| 137 |
-
elif ctype == "scatter" and len(num_cols) >= 2: ax.scatter(df[num_cols[0]], df[num_cols[1]], s=10, alpha=0.7)
|
| 138 |
-
elif ctype == "hist" and len(num_cols) >= 1: ax.hist(df[num_cols[0]], bins=20, alpha=0.7)
|
| 139 |
-
else: df[num_cols[0]].value_counts().head(10).plot(kind="bar", ax=ax)
|
| 140 |
-
ax.set_title(title); fig.tight_layout(); fig.savefig(out, bbox_inches="tight", facecolor="white"); plt.close(fig)
|
| 141 |
-
|
| 142 |
# ─── ENHANCED CHART GENERATION SYSTEM (User's code - unchanged) ───────────
|
| 143 |
class ChartSpecification:
|
| 144 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
|
@@ -195,16 +182,21 @@ class ChartGenerator:
|
|
| 195 |
if response.startswith("```json"): response = response[7:-3]
|
| 196 |
elif response.startswith("```"): response = response[3:-3]
|
| 197 |
spec_dict = json.loads(response)
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def _create_fallback_spec(self, description: str) -> ChartSpecification:
|
| 202 |
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
|
| 203 |
if "bar" in description.lower() and categorical_cols and numeric_cols: return ChartSpecification("bar", description, categorical_cols[0], numeric_cols[0])
|
| 204 |
elif "pie" in description.lower() and categorical_cols and numeric_cols: return ChartSpecification("pie", description, categorical_cols[0], numeric_cols[0])
|
| 205 |
-
elif "line" in description.lower() and len(numeric_cols) >=
|
| 206 |
elif "scatter" in description.lower() and len(numeric_cols) >= 2: return ChartSpecification("scatter", description, numeric_cols[0], numeric_cols[1])
|
| 207 |
-
elif numeric_cols: return ChartSpecification("hist", description, numeric_cols[0], None)
|
| 208 |
else: return ChartSpecification("bar", description, self.df.columns[0], self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 209 |
|
| 210 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
|
@@ -234,7 +226,7 @@ def prepare_plot_data(spec: ChartSpecification, df: pd.DataFrame) -> pd.Series:
|
|
| 234 |
|
| 235 |
# ─── FIXED ANIMATION SYSTEM ───────────────────────────────────────────────
|
| 236 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 237 |
-
"""FIXED: Renders a reliable animated chart using proven patterns,
|
| 238 |
plot_data = prepare_plot_data(spec, df)
|
| 239 |
title = spec.title
|
| 240 |
frames = max(10, int(dur * fps)) # Ensure integer frame count
|
|
@@ -242,6 +234,7 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 242 |
plt.tight_layout(pad=2.5)
|
| 243 |
ctype = spec.chart_type
|
| 244 |
|
|
|
|
| 245 |
if ctype == "pie":
|
| 246 |
wedges, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 247 |
ax.set_title(title); ax.axis('equal')
|
|
@@ -253,12 +246,16 @@ def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: P
|
|
| 253 |
ax.set_title(title); plt.xticks(rotation=45, ha="right")
|
| 254 |
def init(): return bars
|
| 255 |
def update(i):
|
| 256 |
-
|
|
|
|
| 257 |
return bars
|
| 258 |
else: # line, scatter, hist
|
| 259 |
line, = ax.plot([], [], lw=2)
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
| 262 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 263 |
ax.set_title(title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 264 |
def init(): line.set_data([], []); return [line]
|
|
@@ -282,17 +279,27 @@ def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) ->
|
|
| 282 |
return str(out)
|
| 283 |
|
| 284 |
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
|
| 285 |
-
"""FIXED: A simplified and more reliable chart generation wrapper
|
| 286 |
try:
|
| 287 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 288 |
chart_generator = create_chart_generator(llm, df)
|
| 289 |
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 290 |
-
return animate_chart(chart_spec, df, dur, out
|
| 291 |
except Exception as e:
|
| 292 |
-
print(f"Chart animation failed for '{desc}': {e}. Falling back to
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
def concat_media(file_paths: List[str], output_path: Path, media_type: str):
|
| 298 |
"""FIXED: Concatenate multiple media files using FFmpeg, robustly checking for valid files."""
|
|
|
|
| 1 |
##############################################################################
|
| 2 |
# Sozo Business Studio · 10-Jul-2025
|
| 3 |
+
# • FIXED: Animation and FFmpeg errors while preserving the user's AI architecture.
|
| 4 |
# • FIXED: The 'can't multiply sequence' error by replacing the animation engine.
|
| 5 |
# • FIXED: FFmpeg failures with a robust media concatenation function.
|
| 6 |
# • NOTE: The user's prompts, classes, and AI calls are preserved exactly.
|
|
|
|
| 126 |
pdf.set_font("Arial", "", 11); pdf.write_html(html)
|
| 127 |
return pdf.output(dest="S")
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# ─── ENHANCED CHART GENERATION SYSTEM (User's code - unchanged) ───────────
|
| 130 |
class ChartSpecification:
|
| 131 |
def __init__(self, chart_type: str, title: str, x_col: str, y_col: str, agg_method: str = None, filter_condition: str = None, top_n: int = None, color_scheme: str = "professional"):
|
|
|
|
| 182 |
if response.startswith("```json"): response = response[7:-3]
|
| 183 |
elif response.startswith("```"): response = response[3:-3]
|
| 184 |
spec_dict = json.loads(response)
|
| 185 |
+
# Filter to only include keys expected by the ChartSpecification constructor
|
| 186 |
+
valid_keys = [p.name for p in inspect.signature(ChartSpecification).parameters.values()]
|
| 187 |
+
filtered_dict = {k: v for k, v in spec_dict.items() if k in valid_keys}
|
| 188 |
+
return ChartSpecification(**filtered_dict)
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f"Spec generation failed: {e}. Using fallback.")
|
| 191 |
+
return self._create_fallback_spec(description)
|
| 192 |
|
| 193 |
def _create_fallback_spec(self, description: str) -> ChartSpecification:
|
| 194 |
numeric_cols = self.enhanced_ctx['numeric_columns']; categorical_cols = self.enhanced_ctx['categorical_columns']
|
| 195 |
if "bar" in description.lower() and categorical_cols and numeric_cols: return ChartSpecification("bar", description, categorical_cols[0], numeric_cols[0])
|
| 196 |
elif "pie" in description.lower() and categorical_cols and numeric_cols: return ChartSpecification("pie", description, categorical_cols[0], numeric_cols[0])
|
| 197 |
+
elif "line" in description.lower() and len(numeric_cols) >= 1: return ChartSpecification("line", description, self.df.columns[0], numeric_cols[0])
|
| 198 |
elif "scatter" in description.lower() and len(numeric_cols) >= 2: return ChartSpecification("scatter", description, numeric_cols[0], numeric_cols[1])
|
| 199 |
+
elif "hist" in description.lower() and numeric_cols: return ChartSpecification("hist", description, numeric_cols[0], None)
|
| 200 |
else: return ChartSpecification("bar", description, self.df.columns[0], self.df.columns[1] if len(self.df.columns) > 1 else None)
|
| 201 |
|
| 202 |
def execute_chart_spec(spec: ChartSpecification, df: pd.DataFrame, output_path: Path) -> bool:
|
|
|
|
| 226 |
|
| 227 |
# ─── FIXED ANIMATION SYSTEM ───────────────────────────────────────────────
|
| 228 |
def animate_chart(spec: ChartSpecification, df: pd.DataFrame, dur: float, out: Path, fps: int = FPS) -> str:
|
| 229 |
+
"""FIXED: Renders a reliable animated chart using proven patterns, compatible with ChartSpecification."""
|
| 230 |
plot_data = prepare_plot_data(spec, df)
|
| 231 |
title = spec.title
|
| 232 |
frames = max(10, int(dur * fps)) # Ensure integer frame count
|
|
|
|
| 234 |
plt.tight_layout(pad=2.5)
|
| 235 |
ctype = spec.chart_type
|
| 236 |
|
| 237 |
+
# This robust animation logic is adapted from the working example
|
| 238 |
if ctype == "pie":
|
| 239 |
wedges, _ = ax.pie(plot_data, labels=plot_data.index, startangle=90, autopct='%1.1f%%')
|
| 240 |
ax.set_title(title); ax.axis('equal')
|
|
|
|
| 246 |
ax.set_title(title); plt.xticks(rotation=45, ha="right")
|
| 247 |
def init(): return bars
|
| 248 |
def update(i):
|
| 249 |
+
progress = i / (frames - 1)
|
| 250 |
+
for b, h in zip(bars, plot_data.values): b.set_height(h * progress)
|
| 251 |
return bars
|
| 252 |
else: # line, scatter, hist
|
| 253 |
line, = ax.plot([], [], lw=2)
|
| 254 |
+
if ctype == 'scatter':
|
| 255 |
+
x_full, y_full = plot_data.iloc[:, 0], plot_data.iloc[:, 1]
|
| 256 |
+
else:
|
| 257 |
+
plot_data = plot_data.sort_index() if ctype == 'line' and not plot_data.index.is_monotonic_increasing else plot_data
|
| 258 |
+
x_full, y_full = plot_data.index, plot_data.values
|
| 259 |
ax.set_xlim(x_full.min(), x_full.max()); ax.set_ylim(y_full.min() * 0.9, y_full.max() * 1.1)
|
| 260 |
ax.set_title(title); ax.grid(alpha=.3); ax.set_xlabel(spec.x_col); ax.set_ylabel(spec.y_col)
|
| 261 |
def init(): line.set_data([], []); return [line]
|
|
|
|
| 279 |
return str(out)
|
| 280 |
|
| 281 |
def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
|
| 282 |
+
"""FIXED: A simplified and more reliable chart generation wrapper."""
|
| 283 |
try:
|
| 284 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 285 |
chart_generator = create_chart_generator(llm, df)
|
| 286 |
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 287 |
+
return animate_chart(chart_spec, df, dur, out)
|
| 288 |
except Exception as e:
|
| 289 |
+
print(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
|
| 290 |
+
# Fallback: create a static version of the chart and fade it in
|
| 291 |
+
temp_png = out.with_suffix(".png")
|
| 292 |
+
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 293 |
+
chart_generator = create_chart_generator(llm, df)
|
| 294 |
+
chart_spec = chart_generator.generate_chart_spec(desc)
|
| 295 |
+
if execute_chart_spec(chart_spec, df, temp_png):
|
| 296 |
+
img = cv2.imread(str(temp_png))
|
| 297 |
+
img_resized = cv2.resize(img, (WIDTH, HEIGHT))
|
| 298 |
+
return animate_image_fade(img_resized, dur, out)
|
| 299 |
+
else: # Ultimate fallback
|
| 300 |
+
img = generate_image_from_prompt(f"A professional business chart showing {desc}")
|
| 301 |
+
img_cv = cv2.cvtColor(np.array(img.resize((WIDTH, HEIGHT))), cv2.COLOR_RGB2BGR)
|
| 302 |
+
return animate_image_fade(img_cv, dur, out)
|
| 303 |
|
| 304 |
def concat_media(file_paths: List[str], output_path: Path, media_type: str):
|
| 305 |
"""FIXED: Concatenate multiple media files using FFmpeg, robustly checking for valid files."""
|