Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,7 +119,8 @@ TAG_RE = re.compile(
|
|
| 119 |
extract_chart_tags = lambda t: list(dict.fromkeys(m.group("d").strip()
|
| 120 |
for m in TAG_RE.finditer(t or "")))
|
| 121 |
|
| 122 |
-
|
|
|
|
| 123 |
|
| 124 |
def clean_narration(txt: str) -> str:
|
| 125 |
txt = re_scene.sub("", txt)
|
|
@@ -135,7 +136,6 @@ def placeholder_img() -> Image.Image:
|
|
| 135 |
|
| 136 |
@st.cache_data(show_spinner="Generating image...")
|
| 137 |
def generate_image_from_prompt(prompt: str) -> Image.Image:
|
| 138 |
-
# IMPROVED: Using your original model names for consistency with your environment.
|
| 139 |
model_main = "gemini-2.0-flash-exp-image-generation"
|
| 140 |
model_fallback = "gemini-2.0-flash-preview-image-generation"
|
| 141 |
full_prompt = "A clean business-presentation illustration: " + prompt
|
|
@@ -152,7 +152,6 @@ def generate_image_from_prompt(prompt: str) -> Image.Image:
|
|
| 152 |
return Image.open(io.BytesIO(part.inline_data.data)).convert("RGB")
|
| 153 |
return None
|
| 154 |
except Exception:
|
| 155 |
-
# Silently fail to allow fallback
|
| 156 |
return None
|
| 157 |
|
| 158 |
img = fetch(model_main) or fetch(model_fallback)
|
|
@@ -184,10 +183,8 @@ def build_pdf(md: str, charts: Dict[str, str]) -> bytes:
|
|
| 184 |
|
| 185 |
def generate_report_text(df: pd.DataFrame, ctx: str) -> Tuple[str, List[str]]:
|
| 186 |
"""Generates only the text part of the report. This is the fast, first step."""
|
| 187 |
-
# Using your original model name.
|
| 188 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 189 |
|
| 190 |
-
# IMPROVED: Sending a summary instead of the full dataframe is more efficient and robust.
|
| 191 |
ctx_dict = {
|
| 192 |
"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx or "General business analysis",
|
| 193 |
"data_sample": df.head().to_dict('records'),
|
|
@@ -218,13 +215,11 @@ def generate_report_text(df: pd.DataFrame, ctx: str) -> Tuple[str, List[str]]:
|
|
| 218 |
|
| 219 |
def generate_single_chart(description: str, df: pd.DataFrame) -> str:
|
| 220 |
"""Generates one chart using the agent and returns it as a base64 string. More reliable."""
|
| 221 |
-
# Using your original model name.
|
| 222 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 223 |
agent = create_pandas_dataframe_agent(
|
| 224 |
llm=llm, df=df, verbose=False, allow_dangerous_code=True,
|
| 225 |
agent_type="openai-functions", handle_parsing_errors=True
|
| 226 |
)
|
| 227 |
-
# IMPROVED: A more explicit prompt for the agent leads to more reliable code generation.
|
| 228 |
chart_prompt = f"""
|
| 229 |
Your task is to generate Python code to create a single, static, professional chart using matplotlib based on the provided dataframe `df`.
|
| 230 |
The user's request is: '{description}'.
|
|
@@ -255,7 +250,7 @@ def generate_single_chart(description: str, df: pd.DataFrame) -> str:
|
|
| 255 |
except Exception as e:
|
| 256 |
st.warning(f"Chart generation attempt failed: {e}")
|
| 257 |
plt.close("all")
|
| 258 |
-
return None
|
| 259 |
|
| 260 |
# βββ ANIMATION HELPERS (YOUR ORIGINAL CODE) ββββββββββββββββββββββββββββββββ
|
| 261 |
|
|
@@ -264,7 +259,7 @@ def animate_image_fade(img_cv2: np.ndarray, dur: float, out: Path, fps: int = FP
|
|
| 264 |
vid = cv2.VideoWriter(str(out), cv2.VideoWriter_fourcc(*"mp4v"), fps, (WIDTH, HEIGHT))
|
| 265 |
blank = np.full_like(img_cv2, 255)
|
| 266 |
for i in range(frames):
|
| 267 |
-
a = i / (frames - 1)
|
| 268 |
vid.write(cv2.addWeighted(blank, 1 - a, img_cv2, a, 0))
|
| 269 |
vid.release()
|
| 270 |
return str(out)
|
|
@@ -350,17 +345,19 @@ def safe_chart(desc, df, dur, out):
|
|
| 350 |
with plt.ioff():
|
| 351 |
fig, ax = plt.subplots()
|
| 352 |
try:
|
| 353 |
-
# Attempt a simple plot
|
| 354 |
df.select_dtypes(include=np.number).plot(ax=ax)
|
| 355 |
ax.set_title(desc)
|
| 356 |
-
except:
|
| 357 |
-
# If that fails, just show a text error on the image
|
| 358 |
ax.text(0.5, 0.5, 'Could not render chart', ha='center', va='center')
|
| 359 |
|
| 360 |
p = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.png"
|
| 361 |
fig.savefig(p, bbox_inches="tight"); plt.close(fig)
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
def concat_media(paths: List[str], out: Path, kind="video"):
|
| 366 |
if not paths: return
|
|
@@ -369,8 +366,8 @@ def concat_media(paths: List[str], out: Path, kind="video"):
|
|
| 369 |
for p in paths:
|
| 370 |
if Path(p).exists() and Path(p).stat().st_size > 0:
|
| 371 |
f.write(f"file '{Path(p).resolve().as_posix()}'\n")
|
| 372 |
-
if lst_path.stat().st_size == 0:
|
| 373 |
-
lst_path.unlink()
|
| 374 |
return
|
| 375 |
|
| 376 |
cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(lst_path), "-c", "copy", str(out)]
|
|
@@ -470,23 +467,25 @@ def generate_video(buf: bytes, name: str, ctx: str, key: str):
|
|
| 470 |
mode = st.radio("Select Output Format:", ["Report (PDF)", "Video Narrative"], horizontal=True)
|
| 471 |
upl = st.file_uploader("Upload CSV or Excel", type=["csv", "xlsx", "xls"])
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
st.session_state.
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
|
|
|
|
|
|
| 487 |
|
| 488 |
if st.session_state.get("df") is not None:
|
| 489 |
-
with st.expander("π Data Preview"):
|
| 490 |
st.dataframe(arrow_df(st.session_state.df.head()))
|
| 491 |
ctx = st.text_area("Business context or specific instructions (optional)")
|
| 492 |
|
|
@@ -497,7 +496,7 @@ if st.session_state.get("df") is not None:
|
|
| 497 |
st.session_state.report_md = md
|
| 498 |
st.session_state.chart_descs = descs
|
| 499 |
st.rerun()
|
| 500 |
-
else:
|
| 501 |
if st.button("π¬ Generate Video", type="primary"):
|
| 502 |
st.warning("Video generation is a long process and will lock the UI.")
|
| 503 |
with st.spinner("Generating video... This may take several minutes."):
|
|
|
|
| 119 |
extract_chart_tags = lambda t: list(dict.fromkeys(m.group("d").strip()
|
| 120 |
for m in TAG_RE.finditer(t or "")))
|
| 121 |
|
| 122 |
+
# --- FIXED: Escaped the hyphen to treat it as a literal character ---
|
| 123 |
+
re_scene = re.compile(r"^\s*scene\s*\d+[:\.- ]*", re.I)
|
| 124 |
|
| 125 |
def clean_narration(txt: str) -> str:
|
| 126 |
txt = re_scene.sub("", txt)
|
|
|
|
| 136 |
|
| 137 |
@st.cache_data(show_spinner="Generating image...")
|
| 138 |
def generate_image_from_prompt(prompt: str) -> Image.Image:
|
|
|
|
| 139 |
model_main = "gemini-2.0-flash-exp-image-generation"
|
| 140 |
model_fallback = "gemini-2.0-flash-preview-image-generation"
|
| 141 |
full_prompt = "A clean business-presentation illustration: " + prompt
|
|
|
|
| 152 |
return Image.open(io.BytesIO(part.inline_data.data)).convert("RGB")
|
| 153 |
return None
|
| 154 |
except Exception:
|
|
|
|
| 155 |
return None
|
| 156 |
|
| 157 |
img = fetch(model_main) or fetch(model_fallback)
|
|
|
|
| 183 |
|
| 184 |
def generate_report_text(df: pd.DataFrame, ctx: str) -> Tuple[str, List[str]]:
|
| 185 |
"""Generates only the text part of the report. This is the fast, first step."""
|
|
|
|
| 186 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 187 |
|
|
|
|
| 188 |
ctx_dict = {
|
| 189 |
"shape": df.shape, "columns": list(df.columns), "user_ctx": ctx or "General business analysis",
|
| 190 |
"data_sample": df.head().to_dict('records'),
|
|
|
|
| 215 |
|
| 216 |
def generate_single_chart(description: str, df: pd.DataFrame) -> str:
|
| 217 |
"""Generates one chart using the agent and returns it as a base64 string. More reliable."""
|
|
|
|
| 218 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
|
| 219 |
agent = create_pandas_dataframe_agent(
|
| 220 |
llm=llm, df=df, verbose=False, allow_dangerous_code=True,
|
| 221 |
agent_type="openai-functions", handle_parsing_errors=True
|
| 222 |
)
|
|
|
|
| 223 |
chart_prompt = f"""
|
| 224 |
Your task is to generate Python code to create a single, static, professional chart using matplotlib based on the provided dataframe `df`.
|
| 225 |
The user's request is: '{description}'.
|
|
|
|
| 250 |
except Exception as e:
|
| 251 |
st.warning(f"Chart generation attempt failed: {e}")
|
| 252 |
plt.close("all")
|
| 253 |
+
return None
|
| 254 |
|
| 255 |
# βββ ANIMATION HELPERS (YOUR ORIGINAL CODE) ββββββββββββββββββββββββββββββββ
|
| 256 |
|
|
|
|
| 259 |
vid = cv2.VideoWriter(str(out), cv2.VideoWriter_fourcc(*"mp4v"), fps, (WIDTH, HEIGHT))
|
| 260 |
blank = np.full_like(img_cv2, 255)
|
| 261 |
for i in range(frames):
|
| 262 |
+
a = i / (frames - 1) if frames > 1 else 1.0
|
| 263 |
vid.write(cv2.addWeighted(blank, 1 - a, img_cv2, a, 0))
|
| 264 |
vid.release()
|
| 265 |
return str(out)
|
|
|
|
| 345 |
with plt.ioff():
|
| 346 |
fig, ax = plt.subplots()
|
| 347 |
try:
|
|
|
|
| 348 |
df.select_dtypes(include=np.number).plot(ax=ax)
|
| 349 |
ax.set_title(desc)
|
| 350 |
+
except Exception:
|
|
|
|
| 351 |
ax.text(0.5, 0.5, 'Could not render chart', ha='center', va='center')
|
| 352 |
|
| 353 |
p = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}.png"
|
| 354 |
fig.savefig(p, bbox_inches="tight"); plt.close(fig)
|
| 355 |
+
img_path = str(p)
|
| 356 |
+
img = cv2.imread(img_path)
|
| 357 |
+
if img is None: # Handle case where image read fails
|
| 358 |
+
img = np.full((HEIGHT, WIDTH, 3), 230, dtype=np.uint8) # Fallback gray image
|
| 359 |
+
img_resized = cv2.resize(img, (WIDTH, HEIGHT))
|
| 360 |
+
return animate_image_fade(img_resized, dur, out)
|
| 361 |
|
| 362 |
def concat_media(paths: List[str], out: Path, kind="video"):
|
| 363 |
if not paths: return
|
|
|
|
| 366 |
for p in paths:
|
| 367 |
if Path(p).exists() and Path(p).stat().st_size > 0:
|
| 368 |
f.write(f"file '{Path(p).resolve().as_posix()}'\n")
|
| 369 |
+
if not lst_path.is_file() or lst_path.stat().st_size == 0:
|
| 370 |
+
if lst_path.is_file(): lst_path.unlink()
|
| 371 |
return
|
| 372 |
|
| 373 |
cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", str(lst_path), "-c", "copy", str(out)]
|
|
|
|
| 467 |
mode = st.radio("Select Output Format:", ["Report (PDF)", "Video Narrative"], horizontal=True)
|
| 468 |
upl = st.file_uploader("Upload CSV or Excel", type=["csv", "xlsx", "xls"])
|
| 469 |
|
| 470 |
+
if upl:
|
| 471 |
+
file_key = sha1_bytes(upl.getvalue())
|
| 472 |
+
if file_key != st.session_state.current_file_key:
|
| 473 |
+
st.session_state.report_md = None
|
| 474 |
+
st.session_state.chart_descs = []
|
| 475 |
+
st.session_state.generated_charts = {}
|
| 476 |
+
st.session_state.pdf_bytes = None
|
| 477 |
+
st.session_state.bundle = None
|
| 478 |
+
st.session_state.current_file_key = file_key
|
| 479 |
+
df, err = load_dataframe_safely(upl.getvalue(), upl.name)
|
| 480 |
+
if err:
|
| 481 |
+
st.error(f"Error loading data: {err}")
|
| 482 |
+
st.session_state.df = None
|
| 483 |
+
else:
|
| 484 |
+
st.session_state.df = df
|
| 485 |
+
st.rerun()
|
| 486 |
|
| 487 |
if st.session_state.get("df") is not None:
|
| 488 |
+
with st.expander("π Data Preview", expanded=True):
|
| 489 |
st.dataframe(arrow_df(st.session_state.df.head()))
|
| 490 |
ctx = st.text_area("Business context or specific instructions (optional)")
|
| 491 |
|
|
|
|
| 496 |
st.session_state.report_md = md
|
| 497 |
st.session_state.chart_descs = descs
|
| 498 |
st.rerun()
|
| 499 |
+
else: # Video Mode
|
| 500 |
if st.button("π¬ Generate Video", type="primary"):
|
| 501 |
st.warning("Video generation is a long process and will lock the UI.")
|
| 502 |
with st.spinner("Generating video... This may take several minutes."):
|