# transmed_openai / src / streamlit_app.py
# (Hugging Face Space page residue — uploader "DrMostafa", commit ac605b4
#  "Update src/streamlit_app.py", verified; kept here as a comment so the
#  file remains valid Python.)
import streamlit as st
import pandas as pd
import os
import re
import io
import hashlib
import json
import glob
import base64
from datetime import datetime
# Visualization
import plotly.graph_objects as go
import plotly.io as pio
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
# Document Generation
from docx import Document
from docx.shared import Inches
# AI & LlamaIndex
from openai import OpenAI as OpenAIClient
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from llama_index.core.tools import QueryEngineTool, FunctionTool, ToolMetadata
from llama_index.agent.openai import OpenAIAgent
from llama_index.experimental.query_engine import PandasQueryEngine
# Force non-interactive backend for Matplotlib to prevent threading issues
matplotlib.use('Agg')
# ==========================================
# ⚙️ Configuration & Sandbox Setup
# ==========================================
# Wide layout gives dataframes and charts the full viewport width.
st.set_page_config(page_title="Data Agent & Sandbox made for Transmed", page_icon="📂", layout="wide")
# Directory where the agent writes generated artifacts (reports, CSV exports).
SANDBOX_DIR = "sandbox"
if not os.path.exists(SANDBOX_DIR):
    os.makedirs(SANDBOX_DIR)
# ---- Session-state defaults (persist across Streamlit reruns) ----
if "messages" not in st.session_state:
    # Chat history: list of dicts built by add_message().
    st.session_state.messages = []
if "agent" not in st.session_state:
    # Cached OpenAIAgent instance; rebuilt by ensure_agent() on key/file change.
    st.session_state.agent = None
if "dataframes" not in st.session_state:
    # Mapping of sanitized file name -> loaded pandas DataFrame.
    st.session_state.dataframes = {}
if "voice_prompt_text" not in st.session_state:
    # Latest voice transcription awaiting the user's "send" click.
    st.session_state.voice_prompt_text = ""
if "files_fingerprint" not in st.session_state:
    # MD5 fingerprint of the uploaded file set, used to detect changes.
    st.session_state.files_fingerprint = None
if "api_key_cached" not in st.session_state:
    # API key the current agent was built with (rebuild when it changes).
    st.session_state.api_key_cached = None
if "audio_key" not in st.session_state:
    # Counter bumped by the sidebar "Reset" button.
    # NOTE(review): incremented but never passed as a widget key — the reset
    # may not actually clear the recorder; confirm against st.audio_input usage.
    st.session_state.audio_key = 0
# ==========================================
# πŸ› οΈ Helpers
# ==========================================
def sanitize_name(name):
    """Convert a filename into a safe Python identifier (max 60 chars).

    Drops the extension, replaces every character outside [a-zA-Z0-9_]
    with an underscore, and prefixes ``df_`` when the result would start
    with a digit or be empty.

    Fix: the original indexed ``clean[0]`` unconditionally, raising
    IndexError when the sanitized stem was empty.
    """
    stem = os.path.splitext(name)[0]
    clean = re.sub(r'[^a-zA-Z0-9_]', '_', stem)
    # Guard the empty-stem case before peeking at the first character.
    if not clean or clean[0].isdigit():
        clean = "df_" + clean
    return clean[:60]
def get_llm(api_key):
    """Construct the LlamaIndex chat model (GPT-4o) used by the agent."""
    llm = OpenAI(model="gpt-4o", api_key=api_key)
    return llm
def transcribe_audio(api_key, audio_bytes, filename="audio.wav"):
    """Send recorded audio bytes to OpenAI's transcription API and return the text."""
    buffer = io.BytesIO(audio_bytes)
    # The API infers the audio format from the file-like object's name.
    buffer.name = filename
    client = OpenAIClient(api_key=api_key)
    result = client.audio.transcriptions.create(
        model="gpt-4o-mini-transcribe",
        file=buffer,
    )
    return result.text
def fingerprint_files(files):
    """Return an MD5 hex digest over each upload's name and size.

    Cheap change-detection for the uploaded file set — content is not
    hashed, only (name, size) pairs in order.
    """
    digest = hashlib.md5()
    for item in files:
        digest.update(f"{item.name}{item.size}".encode("utf-8"))
    return digest.hexdigest()
def is_plot_request(text):
    """Return a truthy regex match when *text* looks like a charting request, else None."""
    keywords = r"\b(plot|chart|graph|visual|visualize|hist|box|scatter|line|bar)\b"
    return re.search(keywords, text, re.I)
def add_message(role, content, msg_type="text", **kwargs):
    """Append one chat record to the session history.

    Extra keyword arguments (e.g. ``code``, ``mpl_png``, ``plotly_json``)
    are merged into the record and may override the base keys.
    """
    record = {
        "role": role,
        "content": content,
        "type": msg_type,
        "timestamp": datetime.now().isoformat(),
    }
    record.update(kwargs)
    st.session_state.messages.append(record)
def list_sandbox_files():
    """Return all sandbox file paths, newest first by modification time."""
    paths = glob.glob(os.path.join(SANDBOX_DIR, "*"))
    return sorted(paths, key=os.path.getmtime, reverse=True)
# ==========================================
# 🧠 Agent Logic & Tools
# ==========================================
def build_agent(uploaded_files, api_key):
    """Build an OpenAIAgent equipped with three kinds of tools.

    1. One PandasQueryEngine tool per uploaded spreadsheet (calculations,
       filtering, aggregation).
    2. A chart-execution tool that runs model-generated Python code.
    3. A DOCX report generator that embeds every chart from chat history.

    Side effect: repopulates ``st.session_state.dataframes`` with the
    parsed spreadsheets, keyed by their sanitized names.
    """
    llm = get_llm(api_key)
    tools = []
    st.session_state.dataframes = {}
    # 1. Load Dataframes — one query tool per file.
    for file in uploaded_files:
        safe_name = sanitize_name(file.name)
        try:
            # Rewind: the upload buffer may already have been read on a
            # previous Streamlit rerun.
            file.seek(0)
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            else:
                # Anything else accepted by the uploader is Excel (.xlsx/.xls).
                df = pd.read_excel(file)
            # Clean columns so generated code can reference them as identifiers.
            df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
            st.session_state.dataframes[safe_name] = df
            pandas_engine = PandasQueryEngine(
                df=df,
                verbose=True,
                synthesize_response=True,
                llm=llm
            )
            tools.append(
                QueryEngineTool(
                    query_engine=pandas_engine,
                    metadata=ToolMetadata(
                        name=f"tool_{safe_name}",
                        description=(
                            f"Query spreadsheet '{safe_name}'. Use for calculations, filtering, aggregation. "
                            "Not for plotting."
                        )
                    )
                )
            )
        except Exception as e:
            # Best-effort: a bad file is reported in the UI and skipped.
            st.error(f"Error loading {file.name}: {e}")
    # 2. Plotting Tool (Robust Version)
    def plot_generator(code: str):
        """
        Executes Python code to generate charts or manipulate data.

        NOTE(review): exec() runs model-generated code with full process
        privileges — acceptable only for trusted users of this app.
        """
        try:
            # Reset figures to ensure a clean slate for capture below.
            plt.close("all")
            # --- EXECUTION ENVIRONMENT ---
            # Mock plt.show so the agent's code cannot clear the figure buffer.
            def no_op_show(*args, **kwargs):
                pass
            # Inject dependencies and the loaded dataframes by name.
            local_vars = {
                "pd": pd,
                "plt": plt,
                "go": go,
                "st": st
            }
            local_vars.update(st.session_state.dataframes)
            # Override show.
            # NOTE(review): local_vars["plt"] is the shared matplotlib module,
            # so this mutates plt.show globally and is never restored — confirm
            # whether that is intended.
            local_vars["plt"].show = no_op_show
            # Execute the code.
            # NOTE(review): with separate globals/locals mappings, names in
            # local_vars are NOT visible inside functions defined by the
            # executed code (standard exec() scoping) — top-level code works.
            exec(code, globals(), local_vars)
            # --- CAPTURE OUTPUTS ---
            plotly_json = None
            mpl_png = None
            # 1. Check for a 'fig' variable left behind by the executed code.
            if "fig" in local_vars:
                fig_obj = local_vars["fig"]
                # If it's a Plotly figure (duck-typed on to_json).
                if hasattr(fig_obj, "to_json"):
                    plotly_json = fig_obj.to_json()
                # If it's a Matplotlib figure assigned to 'fig'.
                elif isinstance(fig_obj, plt.Figure):
                    buf = io.BytesIO()
                    fig_obj.savefig(buf, format="png", bbox_inches="tight")
                    buf.seek(0)
                    mpl_png = buf.read()
            # 2. Fallback: check the active Matplotlib figure (plt.gcf),
            # only if nothing has been captured yet.
            if not mpl_png and not plotly_json:
                if plt.get_fignums():
                    fig = plt.gcf()
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png", bbox_inches="tight")
                    buf.seek(0)
                    mpl_png = buf.read()
            # --- PERSIST TO HISTORY ---
            # Charts are stored in session history so the report tool and the
            # history renderer can replay them after reruns.
            if mpl_png or plotly_json:
                add_message(
                    role="assistant",
                    content="(chart execution)",
                    msg_type="chart",
                    code=code,
                    plotly_json=plotly_json,
                    mpl_png=mpl_png
                )
                return "Chart generated and saved successfully."
            else:
                return "Code executed, but no chart was created. Did you assign the plot to 'fig'?"
        except Exception as e:
            # Surface the error back to the agent so it can self-correct.
            return f"Error executing code: {e}"
    tools.append(
        FunctionTool.from_defaults(
            fn=plot_generator,
            name="chart_generator",
            description=(
                "Create plots or save modified data files. Input must be valid Python code. "
                "DO NOT read files. Use the loaded dataframes directly. "
                f"Available Dataframes: {', '.join(st.session_state.dataframes.keys())}. "
                "To save a file: df.to_csv('sandbox/filename.csv'). "
                "To plot: Create the plot using matplotlib (plt) or plotly (go). "
                "IMPORTANT: Assign the final figure to a variable named 'fig'. "
                "Example: fig = plt.gcf() OR fig = go.Figure(...)"
            )
        )
    )
    # 3. Report Generation Tool
    def generate_report(summary_text: str, filename: str = "analysis_report.docx"):
        """
        Generates a Word document containing the provided summary text and ALL charts
        generated in the conversation history.
        """
        try:
            doc = Document()
            doc.add_heading('Data Analysis Report', 0)
            doc.add_heading('Executive Summary', level=1)
            doc.add_paragraph(summary_text)
            doc.add_heading('Visualizations', level=1)
            charts_found = 0
            # Walk the chat history and embed every persisted chart.
            for msg in st.session_state.messages:
                if msg.get("type") == "chart":
                    charts_found += 1
                    doc.add_heading(f'Chart #{charts_found}', level=2)
                    if msg.get("mpl_png"):
                        image_stream = io.BytesIO(msg["mpl_png"])
                        doc.add_picture(image_stream, width=Inches(6))
                    elif msg.get("plotly_json"):
                        try:
                            # Re-hydrate the Plotly figure and rasterize it.
                            # NOTE(review): pio.to_image requires kaleido —
                            # confirm it is installed in the deployment.
                            fig = go.Figure(json.loads(msg["plotly_json"]))
                            img_bytes = pio.to_image(fig, format='png')
                            image_stream = io.BytesIO(img_bytes)
                            doc.add_picture(image_stream, width=Inches(6))
                        except Exception:
                            doc.add_paragraph("[Plotly chart could not be rendered to image]")
            save_path = os.path.join(SANDBOX_DIR, filename)
            doc.save(save_path)
            return f"Report generated: {save_path}"
        except Exception as e:
            return f"Failed to generate report: {str(e)}"
    tools.append(
        FunctionTool.from_defaults(
            fn=generate_report,
            name="generate_report",
            description="Creates a Word DOCX report with a text summary and all charts from the chat history."
        )
    )
    df_names = ", ".join(st.session_state.dataframes.keys())
    system_prompt = (
        "You are a Data Science Agent. "
        f"The following dataframes are ALREADY loaded: {df_names}. "
        "DO NOT read files from disk. Use the variable names directly. "
        "1. For calculations, use the dataframe query tool. "
        "2. For charts, use 'chart_generator'. "
        "3. ALWAYS assign your plot to a variable named 'fig'. "
        "4. If the user asks for a report, generate a text summary first, then call 'generate_report'. "
    )
    return OpenAIAgent.from_tools(
        tools,
        llm=llm,
        verbose=True,
        system_prompt=system_prompt
    )
def ensure_agent(api_key, files):
    """Keep the cached agent in sync with the API key and uploaded files.

    Drops the agent when either input is missing; rebuilds it (and clears
    the chat history) whenever the key or the file set changes.
    """
    if not api_key or not files:
        st.session_state.agent = None
        return
    current_fp = fingerprint_files(files)
    unchanged = (
        st.session_state.agent is not None
        and st.session_state.files_fingerprint == current_fp
        and st.session_state.api_key_cached == api_key
    )
    if unchanged:
        return
    with st.spinner("Initializing Agent..."):
        Settings.llm = get_llm(api_key)
        st.session_state.agent = build_agent(files, api_key)
        st.session_state.files_fingerprint = current_fp
        st.session_state.api_key_cached = api_key
        st.session_state.messages = []
        st.success("Agent Ready!")
# ==========================================
# πŸ–₯️ Sidebar
# ==========================================
with st.sidebar:
    # --- Credentials ---
    st.header("1. API Key")
    api_key = st.text_input("OpenAI API Key", type="password")
    # --- Data upload ---
    st.header("2. Data")
    files = st.file_uploader(
        "Upload CSV or Excel",
        type=["csv", "xlsx", "xls"],
        accept_multiple_files=True
    )
    # Rebuild (or drop) the cached agent based on the current inputs.
    ensure_agent(api_key, files)
    if st.session_state.dataframes:
        st.divider()
        st.write("Loaded Dataframes:")
        for name in st.session_state.dataframes:
            st.code(name)
    # --- Sandbox downloads: files the agent wrote to disk ---
    st.divider()
    st.header("📂 Sandbox Files")
    sandbox_files = list_sandbox_files()
    if not sandbox_files:
        st.write("No files yet.")
    else:
        for fpath in sandbox_files:
            fname = os.path.basename(fpath)
            with open(fpath, "rb") as f:
                # NOTE(review): multiple download_buttons without explicit
                # key= may collide if two sandbox files share a label —
                # confirm filenames stay unique.
                st.download_button(
                    label=f"⬇️ {fname}",
                    data=f,
                    file_name=fname,
                    mime="application/octet-stream"
                )
    # --- Voice input ---
    st.divider()
    st.header("3. Voice Prompt")
    col_a, col_b = st.columns([3, 1])
    with col_a:
        st.markdown("**Record a voice message**")
    with col_b:
        if st.button("↻ Reset"):
            # audio_key is bumped here but not wired to the recorder widget
            # below — see the session-state note at the top of the file.
            st.session_state.audio_key += 1
            st.session_state.voice_prompt_text = ""
            st.rerun()
    audio_value = st.audio_input("Record a voice message")
    if audio_value is not None:
        st.audio(audio_value)
        if st.button("📝 Transcribe Voice"):
            if not api_key:
                st.error("Please enter your API Key first.")
            else:
                with st.spinner("Transcribing..."):
                    try:
                        audio_bytes = audio_value.getbuffer()
                        st.session_state.voice_prompt_text = transcribe_audio(
                            api_key, audio_bytes, filename=audio_value.name or "audio.wav"
                        )
                        st.success("Transcription ready.")
                    except Exception as e:
                        st.error(f"Transcription error: {e}")
# ==========================================
# 💬 Chat Interface
# ==========================================
st.title("⚡ Data Agent & Sandbox made for Transmed")
def process_prompt(prompt: str):
    """Record and display the user's prompt, then stream the agent's reply.

    Appends both sides of the exchange to session history; triggers a
    rerun afterwards so newly written sandbox files and charts appear.
    """
    add_message("user", prompt, "text")
    with st.chat_message("user"):
        st.markdown(prompt)
    if st.session_state.agent:
        with st.chat_message("assistant"):
            try:
                final_prompt = prompt
                if is_plot_request(prompt):
                    # Nudge the agent toward the chart tool contract.
                    final_prompt += "\n\nIMPORTANT: Call chart_generator. Assign the plot to 'fig'."
                response_stream = st.session_state.agent.stream_chat(final_prompt)
                full_response = st.write_stream(response_stream.response_gen)
                add_message("assistant", full_response, "text")
                # Force refresh to show files and charts.
                # NOTE(review): st.rerun() raises a control-flow exception;
                # on older Streamlit versions that exception derives from
                # Exception and would be swallowed by the handler below —
                # verify against the deployed Streamlit version.
                st.rerun()
            except Exception as e:
                st.error(f"Error: {e}")
    else:
        st.info("Please enter API key and upload files.")
# Render history: replay every stored message (text and charts) each rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        if msg["type"] == "text":
            st.markdown(msg["content"])
        elif msg["type"] == "chart":
            st.markdown("**Generated chart code:**")
            st.code(msg.get("code", ""), language="python")
            if msg.get("plotly_json"):
                try:
                    # Re-hydrate the Plotly figure from its saved JSON.
                    fig = go.Figure(json.loads(msg["plotly_json"]))
                    st.plotly_chart(fig, use_container_width=True)
                except Exception:
                    st.warning("Failed to render saved Plotly chart.")
            elif msg.get("mpl_png"):
                # Matplotlib charts were persisted as raw PNG bytes.
                st.image(msg["mpl_png"], use_container_width=True)
            else:
                st.error("Chart data was not captured correctly.")
# Pending voice transcription: show it and offer a one-click send.
if st.session_state.voice_prompt_text:
    with st.container():
        st.info(f"Voice prompt ready: {st.session_state.voice_prompt_text}")
        if st.button("📨 Send voice prompt"):
            prompt = st.session_state.voice_prompt_text
            # Clear before sending so the banner disappears on the next rerun.
            st.session_state.voice_prompt_text = ""
            process_prompt(prompt)
# Regular typed chat input.
if prompt := st.chat_input("Ask: 'Plot sales, then create a Word report'"):
    process_prompt(prompt)