"""Streamlit data-analysis agent with a local file sandbox.

Users upload CSV/Excel files, which are loaded into pandas DataFrames and
exposed to an OpenAI-backed LlamaIndex agent through three kinds of tools:
per-dataframe query engines, a chart-execution tool for agent-written
plotting code, and a Word-report builder. Voice prompts are transcribed via
the OpenAI audio API. Generated files land in the local ``sandbox`` folder.
"""

import streamlit as st
import pandas as pd
import os
import re
import io
import hashlib
import json
import glob
import base64
from datetime import datetime

# Visualization
import plotly.graph_objects as go
import plotly.io as pio
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

# Document Generation
from docx import Document
from docx.shared import Inches

# AI & LlamaIndex
from openai import OpenAI as OpenAIClient
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from llama_index.core.tools import QueryEngineTool, FunctionTool, ToolMetadata
from llama_index.agent.openai import OpenAIAgent
from llama_index.experimental.query_engine import PandasQueryEngine

# Force non-interactive backend for Matplotlib to prevent threading issues
matplotlib.use('Agg')

# ==========================================
# ⚙️ Configuration & Sandbox Setup
# ==========================================
st.set_page_config(page_title="Data Agent & Sandbox made for Transmed", page_icon="📂", layout="wide")

SANDBOX_DIR = "sandbox"
os.makedirs(SANDBOX_DIR, exist_ok=True)

# Session-state defaults. Streamlit re-executes this script on every
# interaction, so anything that must survive a rerun lives in session_state.
_SESSION_DEFAULTS = {
    "messages": [],            # chat history (text + chart entries)
    "agent": None,             # the OpenAIAgent instance
    "dataframes": {},          # sanitized_name -> pd.DataFrame
    "voice_prompt_text": "",   # last transcription awaiting send
    "files_fingerprint": None, # md5 of the current upload set
    "api_key_cached": None,    # key the agent was built with
    "audio_key": 0,            # counter used to reset the audio widget
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


# ==========================================
# 🛠️ Helpers
# ==========================================
def sanitize_name(name):
    """Convert a filename to a valid Python identifier (max 60 chars).

    Strips the extension, replaces illegal characters with underscores, and
    prefixes ``df_`` when the result would start with a digit or be empty.
    """
    name = os.path.splitext(name)[0]
    clean = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    # FIX: also guard the empty case (e.g. a file named "!!!.csv"), which
    # previously raised IndexError on clean[0].
    if not clean or clean[0].isdigit():
        clean = "df_" + clean
    return clean[:60]


def get_llm(api_key):
    """Return the chat LLM shared by the query tools and the agent."""
    return OpenAI(model="gpt-4o", api_key=api_key)


def transcribe_audio(api_key, audio_bytes, filename="audio.wav"):
    """Transcribe raw audio bytes with OpenAI's transcription endpoint.

    The BytesIO wrapper is given a ``name`` attribute so the API can infer
    the audio container format from the extension.
    """
    client = OpenAIClient(api_key=api_key)
    audio_file = io.BytesIO(audio_bytes)
    audio_file.name = filename
    resp = client.audio.transcriptions.create(
        model="gpt-4o-mini-transcribe",
        file=audio_file
    )
    return resp.text


def fingerprint_files(files):
    """Cheap MD5 fingerprint of the uploaded-file set (names + sizes).

    Used only to detect upload changes between reruns — not for security.
    """
    hasher = hashlib.md5()
    for f in files:
        hasher.update(f.name.encode("utf-8"))
        hasher.update(str(f.size).encode("utf-8"))
    return hasher.hexdigest()


def is_plot_request(text):
    """Return True when the prompt looks like it asks for a visualization."""
    # FIX: return a real bool instead of a Match object / None.
    return bool(re.search(
        r"\b(plot|chart|graph|visual|visualize|hist|box|scatter|line|bar)\b",
        text, re.I,
    ))


def add_message(role, content, msg_type="text", **kwargs):
    """Append one chat entry to the persistent message history."""
    st.session_state.messages.append({
        "role": role,
        "content": content,
        "type": msg_type,
        "timestamp": datetime.now().isoformat(),
        **kwargs
    })


def list_sandbox_files():
    """Return sandbox file paths, newest first."""
    files = glob.glob(os.path.join(SANDBOX_DIR, "*"))
    files.sort(key=os.path.getmtime, reverse=True)
    return files


# ==========================================
# 🧠 Agent Logic & Tools
# ==========================================
def build_agent(uploaded_files, api_key):
    """Load the uploaded spreadsheets and assemble the OpenAI agent.

    Creates one PandasQueryEngine tool per file, plus a chart-execution tool
    and a DOCX report generator. Returns a ready-to-chat ``OpenAIAgent``.
    """
    llm = get_llm(api_key)
    tools = []
    st.session_state.dataframes = {}

    # 1. Load Dataframes
    for file in uploaded_files:
        safe_name = sanitize_name(file.name)
        try:
            file.seek(0)  # the upload buffer may have been read on a prior rerun
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            else:
                df = pd.read_excel(file)
            # Normalize columns so generated pandas code can reference them
            df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
            st.session_state.dataframes[safe_name] = df

            pandas_engine = PandasQueryEngine(
                df=df,
                verbose=True,
                synthesize_response=True,
                llm=llm
            )
            tools.append(
                QueryEngineTool(
                    query_engine=pandas_engine,
                    metadata=ToolMetadata(
                        name=f"tool_{safe_name}",
                        description=(
                            f"Query spreadsheet '{safe_name}'. Use for calculations, filtering, aggregation. "
                            "Not for plotting."
                        )
                    )
                )
            )
        except Exception as e:
            st.error(f"Error loading {file.name}: {e}")

    # 2. Plotting Tool (Robust Version)
    def plot_generator(code: str):
        """
        Executes Python code to generate charts or manipulate data.
        """
        original_show = plt.show
        try:
            # Reset figures to ensure a clean slate
            plt.close("all")

            # --- EXECUTION ENVIRONMENT ---
            # Mock plt.show to prevent the agent from clearing the figure buffer
            def no_op_show(*args, **kwargs):
                pass

            # Inject dependencies and the loaded dataframes
            local_vars = {"pd": pd, "plt": plt, "go": go, "st": st}
            local_vars.update(st.session_state.dataframes)

            # FIX: this patches the shared pyplot module, so it is restored in
            # the finally block instead of leaking across calls.
            plt.show = no_op_show

            # SECURITY NOTE: exec() runs LLM-generated code with full process
            # privileges. Acceptable only for a trusted, single-user tool.
            exec(code, globals(), local_vars)

            # --- CAPTURE OUTPUTS ---
            plotly_json = None
            mpl_png = None

            # 1. Check for a 'fig' variable (Plotly or Matplotlib)
            if "fig" in local_vars:
                fig_obj = local_vars["fig"]
                if hasattr(fig_obj, "to_json"):
                    # Plotly figure
                    plotly_json = fig_obj.to_json()
                elif isinstance(fig_obj, plt.Figure):
                    # Matplotlib figure assigned to 'fig'
                    buf = io.BytesIO()
                    fig_obj.savefig(buf, format="png", bbox_inches="tight")
                    buf.seek(0)
                    mpl_png = buf.read()

            # 2. Fallback: the active Matplotlib figure, if nothing captured yet
            if not mpl_png and not plotly_json:
                if plt.get_fignums():
                    fig = plt.gcf()
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png", bbox_inches="tight")
                    buf.seek(0)
                    mpl_png = buf.read()

            # --- PERSIST TO HISTORY ---
            if mpl_png or plotly_json:
                add_message(
                    role="assistant",
                    content="(chart execution)",
                    msg_type="chart",
                    code=code,
                    plotly_json=plotly_json,
                    mpl_png=mpl_png
                )
                return "Chart generated and saved successfully."
            return "Code executed, but no chart was created. Did you assign the plot to 'fig'?"
        except Exception as e:
            return f"Error executing code: {e}"
        finally:
            # Undo the module-level plt.show patch
            plt.show = original_show

    tools.append(
        FunctionTool.from_defaults(
            fn=plot_generator,
            name="chart_generator",
            description=(
                "Create plots or save modified data files. Input must be valid Python code. "
                "DO NOT read files. Use the loaded dataframes directly. "
                f"Available Dataframes: {', '.join(st.session_state.dataframes.keys())}. "
                "To save a file: df.to_csv('sandbox/filename.csv'). "
                "To plot: Create the plot using matplotlib (plt) or plotly (go). "
                "IMPORTANT: Assign the final figure to a variable named 'fig'. "
                "Example: fig = plt.gcf() OR fig = go.Figure(...)"
            )
        )
    )

    # 3. Report Generation Tool
    def generate_report(summary_text: str, filename: str = "analysis_report.docx"):
        """
        Generates a Word document containing the provided summary text and
        ALL charts generated in the conversation history.
        """
        try:
            doc = Document()
            doc.add_heading('Data Analysis Report', 0)
            doc.add_heading('Executive Summary', level=1)
            doc.add_paragraph(summary_text)
            doc.add_heading('Visualizations', level=1)

            charts_found = 0
            for msg in st.session_state.messages:
                if msg.get("type") == "chart":
                    charts_found += 1
                    doc.add_heading(f'Chart #{charts_found}', level=2)
                    if msg.get("mpl_png"):
                        image_stream = io.BytesIO(msg["mpl_png"])
                        doc.add_picture(image_stream, width=Inches(6))
                    elif msg.get("plotly_json"):
                        try:
                            fig = go.Figure(json.loads(msg["plotly_json"]))
                            # NOTE: pio.to_image needs a rasterization engine
                            # (kaleido) installed; failures fall back to text.
                            img_bytes = pio.to_image(fig, format='png')
                            image_stream = io.BytesIO(img_bytes)
                            doc.add_picture(image_stream, width=Inches(6))
                        except Exception:
                            doc.add_paragraph("[Plotly chart could not be rendered to image]")

            save_path = os.path.join(SANDBOX_DIR, filename)
            doc.save(save_path)
            return f"Report generated: {save_path}"
        except Exception as e:
            return f"Failed to generate report: {str(e)}"

    tools.append(
        FunctionTool.from_defaults(
            fn=generate_report,
            name="generate_report",
            description="Creates a Word DOCX report with a text summary and all charts from the chat history."
        )
    )

    df_names = ", ".join(st.session_state.dataframes.keys())
    system_prompt = (
        "You are a Data Science Agent. "
        f"The following dataframes are ALREADY loaded: {df_names}. "
        "DO NOT read files from disk. Use the variable names directly. "
        "1. For calculations, use the dataframe query tool. "
        "2. For charts, use 'chart_generator'. "
        "3. ALWAYS assign your plot to a variable named 'fig'. "
        "4. If the user asks for a report, generate a text summary first, then call 'generate_report'. "
    )

    return OpenAIAgent.from_tools(
        tools,
        llm=llm,
        verbose=True,
        system_prompt=system_prompt
    )


def ensure_agent(api_key, files):
    """(Re)build the agent whenever the API key or the upload set changes."""
    if not api_key or not files:
        st.session_state.agent = None
        return
    fp = fingerprint_files(files)
    if (
        st.session_state.agent is None
        or st.session_state.files_fingerprint != fp
        or st.session_state.api_key_cached != api_key
    ):
        with st.spinner("Initializing Agent..."):
            Settings.llm = get_llm(api_key)
            st.session_state.agent = build_agent(files, api_key)
            st.session_state.files_fingerprint = fp
            st.session_state.api_key_cached = api_key
            st.session_state.messages = []  # new data => fresh conversation
        st.success("Agent Ready!")


# ==========================================
# 🖥️ Sidebar
# ==========================================
with st.sidebar:
    st.header("1. API Key")
    api_key = st.text_input("OpenAI API Key", type="password")

    st.header("2. Data")
    files = st.file_uploader(
        "Upload CSV or Excel",
        type=["csv", "xlsx", "xls"],
        accept_multiple_files=True
    )

    ensure_agent(api_key, files)

    if st.session_state.dataframes:
        st.divider()
        st.write("Loaded Dataframes:")
        for name in st.session_state.dataframes:
            st.code(name)

    st.divider()
    st.header("📂 Sandbox Files")
    sandbox_files = list_sandbox_files()
    if not sandbox_files:
        st.write("No files yet.")
    else:
        for fpath in sandbox_files:
            fname = os.path.basename(fpath)
            with open(fpath, "rb") as f:
                st.download_button(
                    label=f"⬇️ {fname}",
                    data=f,
                    file_name=fname,
                    mime="application/octet-stream"
                )

    st.divider()
    st.header("3. Voice Prompt")
    col_a, col_b = st.columns([3, 1])
    with col_a:
        st.markdown("**Record a voice message**")
    with col_b:
        if st.button("↻ Reset"):
            st.session_state.audio_key += 1
            st.session_state.voice_prompt_text = ""
            st.rerun()

    # FIX: bind the widget identity to audio_key so the Reset button actually
    # discards the previous recording (the counter was maintained but unused).
    audio_value = st.audio_input(
        "Record a voice message",
        key=f"voice_recorder_{st.session_state.audio_key}",
    )
    if audio_value is not None:
        st.audio(audio_value)
        if st.button("📝 Transcribe Voice"):
            if not api_key:
                st.error("Please enter your API Key first.")
            else:
                with st.spinner("Transcribing..."):
                    try:
                        audio_bytes = audio_value.getbuffer()
                        st.session_state.voice_prompt_text = transcribe_audio(
                            api_key, audio_bytes, filename=audio_value.name or "audio.wav"
                        )
                        st.success("Transcription ready.")
                    except Exception as e:
                        st.error(f"Transcription error: {e}")


# ==========================================
# 💬 Chat Interface
# ==========================================
st.title("⚡ Data Agent & Sandbox made for Transmed")


def process_prompt(prompt: str):
    """Send one user prompt through the agent and persist the exchange."""
    add_message("user", prompt, "text")
    with st.chat_message("user"):
        st.markdown(prompt)

    if st.session_state.agent:
        succeeded = False
        with st.chat_message("assistant"):
            try:
                final_prompt = prompt
                if is_plot_request(prompt):
                    # Nudge the agent toward the chart tool for plot requests
                    final_prompt += "\n\nIMPORTANT: Call chart_generator. Assign the plot to 'fig'."
                response_stream = st.session_state.agent.stream_chat(final_prompt)
                full_response = st.write_stream(response_stream.response_gen)
                add_message("assistant", full_response, "text")
                succeeded = True
            except Exception as e:
                st.error(f"Error: {e}")
        if succeeded:
            # FIX: st.rerun() works by raising a control-flow exception. It was
            # previously called inside the try block, where `except Exception`
            # swallowed it and displayed a bogus error. Refresh here, outside
            # the handler, to show new files and charts.
            st.rerun()
    else:
        st.info("Please enter API key and upload files.")


# Render history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        if msg["type"] == "text":
            st.markdown(msg["content"])
        elif msg["type"] == "chart":
            st.markdown("**Generated chart code:**")
            st.code(msg.get("code", ""), language="python")
            if msg.get("plotly_json"):
                try:
                    fig = go.Figure(json.loads(msg["plotly_json"]))
                    st.plotly_chart(fig, use_container_width=True)
                except Exception:
                    st.warning("Failed to render saved Plotly chart.")
            elif msg.get("mpl_png"):
                st.image(msg["mpl_png"], use_container_width=True)
            else:
                st.error("Chart data was not captured correctly.")

if st.session_state.voice_prompt_text:
    with st.container():
        st.info(f"Voice prompt ready: {st.session_state.voice_prompt_text}")
        if st.button("📨 Send voice prompt"):
            prompt = st.session_state.voice_prompt_text
            st.session_state.voice_prompt_text = ""
            process_prompt(prompt)

if prompt := st.chat_input("Ask: 'Plot sales, then create a Word report'"):
    process_prompt(prompt)