Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import os | |
| import re | |
| import io | |
| import hashlib | |
| import json | |
| import glob | |
| import base64 | |
| from datetime import datetime | |
| # Visualization | |
| import plotly.graph_objects as go | |
| import plotly.io as pio | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| # Document Generation | |
| from docx import Document | |
| from docx.shared import Inches | |
| # AI & LlamaIndex | |
| from openai import OpenAI as OpenAIClient | |
| from llama_index.llms.openai import OpenAI | |
| from llama_index.core import Settings | |
| from llama_index.core.tools import QueryEngineTool, FunctionTool, ToolMetadata | |
| from llama_index.agent.openai import OpenAIAgent | |
| from llama_index.experimental.query_engine import PandasQueryEngine | |
# Force non-interactive backend for Matplotlib to prevent threading issues
# (Streamlit runs script code on worker threads; GUI backends are not thread-safe).
matplotlib.use('Agg')

# ==========================================
# βοΈ Configuration & Sandbox Setup
# ==========================================
st.set_page_config(page_title="Data Agent & Sandbox made for Transmed", page_icon="π", layout="wide")

# Directory where agent-generated artifacts (reports, exported CSVs) are written.
SANDBOX_DIR = "sandbox"
# exist_ok=True avoids the check-then-create race of the original
# `if not os.path.exists(...)` pattern.
os.makedirs(SANDBOX_DIR, exist_ok=True)

# One-time per-session defaults. The literals are re-created on every script
# run, so the mutable defaults ([] / {}) are never shared between sessions.
_SESSION_DEFAULTS = {
    "messages": [],            # chat history (list of message dicts)
    "agent": None,             # the OpenAIAgent instance, built lazily
    "dataframes": {},          # name -> pd.DataFrame for uploaded files
    "voice_prompt_text": "",   # pending transcribed voice prompt
    "files_fingerprint": None, # fingerprint of the current upload set
    "api_key_cached": None,    # key the current agent was built with
    "audio_key": 0,            # bump to reset the audio widget
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
| # ========================================== | |
| # π οΈ Helpers | |
| # ========================================== | |
def sanitize_name(name):
    """Convert a filename into a valid Python variable name.

    Strips the extension, replaces every non-identifier character with '_',
    prefixes 'df_' when the result starts with a digit, and truncates to
    60 characters.

    Args:
        name: The original filename (may include an extension).

    Returns:
        A string usable as a Python identifier.
    """
    stem = os.path.splitext(name)[0]
    clean = re.sub(r'[^a-zA-Z0-9_]', '_', stem)
    # BUGFIX: an empty input made clean[0] raise IndexError; identifiers
    # also must not start with a digit — both cases get the 'df_' prefix.
    if not clean or clean[0].isdigit():
        clean = "df_" + clean
    return clean[:60]
def get_llm(api_key):
    """Construct the GPT-4o chat model wrapper shared by the agent and query engines."""
    return OpenAI(api_key=api_key, model="gpt-4o")
def transcribe_audio(api_key, audio_bytes, filename="audio.wav"):
    """Send raw audio bytes to OpenAI's transcription endpoint and return the text."""
    buffer = io.BytesIO(audio_bytes)
    # The API infers the audio format from the file-like object's name attribute.
    buffer.name = filename
    client = OpenAIClient(api_key=api_key)
    response = client.audio.transcriptions.create(
        model="gpt-4o-mini-transcribe",
        file=buffer
    )
    return response.text
| def fingerprint_files(files): | |
| hasher = hashlib.md5() | |
| for f in files: | |
| hasher.update(f.name.encode("utf-8")) | |
| hasher.update(str(f.size).encode("utf-8")) | |
| return hasher.hexdigest() | |
# Compiled once at import time; matches common charting vocabulary, case-insensitively.
_PLOT_PATTERN = re.compile(r"\b(plot|chart|graph|visual|visualize|hist|box|scatter|line|bar)\b", re.I)

def is_plot_request(text):
    """Return a truthy re.Match when *text* looks like a charting request, else None."""
    return _PLOT_PATTERN.search(text)
def add_message(role, content, msg_type="text", **kwargs):
    """Append one structured entry to the session chat history."""
    entry = {
        "role": role,
        "content": content,
        "type": msg_type,
        "timestamp": datetime.now().isoformat(),
    }
    # Extra payload (e.g. chart code/bytes) is merged last, so it may
    # override the defaults — same precedence as the original dict literal.
    entry.update(kwargs)
    st.session_state.messages.append(entry)
def list_sandbox_files():
    """Return all sandbox file paths, newest first by modification time."""
    paths = glob.glob(os.path.join(SANDBOX_DIR, "*"))
    return sorted(paths, key=os.path.getmtime, reverse=True)
| # ========================================== | |
| # π§ Agent Logic & Tools | |
| # ========================================== | |
def build_agent(uploaded_files, api_key):
    """
    Build the OpenAIAgent for the current upload set.

    For every uploaded CSV/Excel file a PandasQueryEngine tool is created
    (calculations, filtering, aggregation). Two function tools are added on
    top: 'chart_generator' (executes agent-authored plotting code) and
    'generate_report' (writes a DOCX with the summary and all charts).

    Args:
        uploaded_files: Streamlit UploadedFile objects (CSV/XLSX/XLS).
        api_key: OpenAI API key used for the agent and query engines.

    Returns:
        An OpenAIAgent wired with all tools and a system prompt listing
        the loaded dataframe names.
    """
    llm = get_llm(api_key)
    tools = []
    st.session_state.dataframes = {}

    # 1. Load Dataframes
    for file in uploaded_files:
        safe_name = sanitize_name(file.name)
        try:
            file.seek(0)  # the upload buffer may have been consumed on a prior run
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            else:
                df = pd.read_excel(file)
            # Clean columns so the LLM can address them as plain identifiers
            df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
            st.session_state.dataframes[safe_name] = df
            pandas_engine = PandasQueryEngine(
                df=df,
                verbose=True,
                synthesize_response=True,
                llm=llm
            )
            tools.append(
                QueryEngineTool(
                    query_engine=pandas_engine,
                    metadata=ToolMetadata(
                        name=f"tool_{safe_name}",
                        description=(
                            f"Query spreadsheet '{safe_name}'. Use for calculations, filtering, aggregation. "
                            "Not for plotting."
                        )
                    )
                )
            )
        except Exception as e:
            st.error(f"Error loading {file.name}: {e}")

    # 2. Plotting Tool (Robust Version)
    def plot_generator(code: str):
        """
        Execute agent-authored Python code to generate charts or manipulate data.

        Captures, in order of preference: a Plotly figure bound to 'fig', a
        Matplotlib figure bound to 'fig', or the active pyplot figure, and
        persists the result into the chat history.
        """
        try:
            plt.close("all")  # clean slate: drop figures left over from earlier calls

            # A single namespace serves as BOTH globals and locals for exec().
            # BUGFIX: with separate globals/locals dicts, names injected into
            # the locals mapping are invisible inside comprehensions and
            # nested functions of the executed code (class-body scoping).
            exec_env = {
                "pd": pd,
                "plt": plt,
                "go": go,
                "st": st,
            }
            exec_env.update(st.session_state.dataframes)

            # Neutralize plt.show so the agent's code cannot clear the figure
            # buffer — and restore it afterwards. BUGFIX: the original
            # assignment monkeypatched the shared pyplot module permanently.
            original_show = plt.show
            plt.show = lambda *args, **kwargs: None
            try:
                # SECURITY: exec() of LLM-generated code is intentional in this
                # trusted, local sandbox app, but would be unsafe for an
                # untrusted multi-user deployment.
                exec(code, exec_env)
            finally:
                plt.show = original_show

            # --- CAPTURE OUTPUTS ---
            plotly_json = None
            mpl_png = None

            fig_obj = exec_env.get("fig")
            if fig_obj is not None:
                if hasattr(fig_obj, "to_json"):
                    # Plotly figure
                    plotly_json = fig_obj.to_json()
                elif isinstance(fig_obj, plt.Figure):
                    # Matplotlib figure explicitly assigned to 'fig'
                    buf = io.BytesIO()
                    fig_obj.savefig(buf, format="png", bbox_inches="tight")
                    buf.seek(0)
                    mpl_png = buf.read()

            # Fallback: the active Matplotlib figure, if nothing captured yet
            if not mpl_png and not plotly_json and plt.get_fignums():
                buf = io.BytesIO()
                plt.gcf().savefig(buf, format="png", bbox_inches="tight")
                buf.seek(0)
                mpl_png = buf.read()

            # --- PERSIST TO HISTORY ---
            if mpl_png or plotly_json:
                add_message(
                    role="assistant",
                    content="(chart execution)",
                    msg_type="chart",
                    code=code,
                    plotly_json=plotly_json,
                    mpl_png=mpl_png
                )
                return "Chart generated and saved successfully."
            return "Code executed, but no chart was created. Did you assign the plot to 'fig'?"
        except Exception as e:
            return f"Error executing code: {e}"

    tools.append(
        FunctionTool.from_defaults(
            fn=plot_generator,
            name="chart_generator",
            description=(
                "Create plots or save modified data files. Input must be valid Python code. "
                "DO NOT read files. Use the loaded dataframes directly. "
                f"Available Dataframes: {', '.join(st.session_state.dataframes.keys())}. "
                "To save a file: df.to_csv('sandbox/filename.csv'). "
                "To plot: Create the plot using matplotlib (plt) or plotly (go). "
                "IMPORTANT: Assign the final figure to a variable named 'fig'. "
                "Example: fig = plt.gcf() OR fig = go.Figure(...)"
            )
        )
    )

    # 3. Report Generation Tool
    def generate_report(summary_text: str, filename: str = "analysis_report.docx"):
        """
        Generate a Word document containing the provided summary text and ALL
        charts recorded in the conversation history.
        """
        try:
            doc = Document()
            doc.add_heading('Data Analysis Report', 0)
            doc.add_heading('Executive Summary', level=1)
            doc.add_paragraph(summary_text)
            doc.add_heading('Visualizations', level=1)
            charts_found = 0
            for msg in st.session_state.messages:
                if msg.get("type") == "chart":
                    charts_found += 1
                    doc.add_heading(f'Chart #{charts_found}', level=2)
                    if msg.get("mpl_png"):
                        image_stream = io.BytesIO(msg["mpl_png"])
                        doc.add_picture(image_stream, width=Inches(6))
                    elif msg.get("plotly_json"):
                        try:
                            fig = go.Figure(json.loads(msg["plotly_json"]))
                            # Static export needs a rendering engine (e.g. kaleido);
                            # degrade to a placeholder paragraph when unavailable.
                            img_bytes = pio.to_image(fig, format='png')
                            image_stream = io.BytesIO(img_bytes)
                            doc.add_picture(image_stream, width=Inches(6))
                        except Exception:
                            doc.add_paragraph("[Plotly chart could not be rendered to image]")
            save_path = os.path.join(SANDBOX_DIR, filename)
            doc.save(save_path)
            return f"Report generated: {save_path}"
        except Exception as e:
            return f"Failed to generate report: {str(e)}"

    tools.append(
        FunctionTool.from_defaults(
            fn=generate_report,
            name="generate_report",
            description="Creates a Word DOCX report with a text summary and all charts from the chat history."
        )
    )

    df_names = ", ".join(st.session_state.dataframes.keys())
    system_prompt = (
        "You are a Data Science Agent. "
        f"The following dataframes are ALREADY loaded: {df_names}. "
        "DO NOT read files from disk. Use the variable names directly. "
        "1. For calculations, use the dataframe query tool. "
        "2. For charts, use 'chart_generator'. "
        "3. ALWAYS assign your plot to a variable named 'fig'. "
        "4. If the user asks for a report, generate a text summary first, then call 'generate_report'. "
    )
    return OpenAIAgent.from_tools(
        tools,
        llm=llm,
        verbose=True,
        system_prompt=system_prompt
    )
def ensure_agent(api_key, files):
    """(Re)build the agent when the API key or upload set changes; clear it when inputs are missing."""
    if not api_key or not files:
        st.session_state.agent = None
        return
    fp = fingerprint_files(files)
    # The cached agent is still valid only when all three match.
    unchanged = (
        st.session_state.agent is not None
        and st.session_state.files_fingerprint == fp
        and st.session_state.api_key_cached == api_key
    )
    if unchanged:
        return
    with st.spinner("Initializing Agent..."):
        Settings.llm = get_llm(api_key)
        st.session_state.agent = build_agent(files, api_key)
        st.session_state.files_fingerprint = fp
        st.session_state.api_key_cached = api_key
        # New data set: the old conversation no longer applies.
        st.session_state.messages = []
    st.success("Agent Ready!")
# ==========================================
# π₯οΈ Sidebar
# ==========================================
# Sidebar drives the whole app: API key, file uploads (which trigger agent
# construction), sandbox downloads, and voice-prompt recording/transcription.
with st.sidebar:
    st.header("1. API Key")
    api_key = st.text_input("OpenAI API Key", type="password")
    st.header("2. Data")
    files = st.file_uploader(
        "Upload CSV or Excel",
        type=["csv", "xlsx", "xls"],
        accept_multiple_files=True
    )
    # Builds/rebuilds the agent when key or files change; clears it otherwise.
    ensure_agent(api_key, files)
    if st.session_state.dataframes:
        st.divider()
        st.write("Loaded Dataframes:")
        # Show the sanitized variable names the agent will use.
        for name in st.session_state.dataframes:
            st.code(name)
    st.divider()
    st.header("π Sandbox Files")
    # Anything the agent wrote to the sandbox dir is downloadable here,
    # newest first.
    sandbox_files = list_sandbox_files()
    if not sandbox_files:
        st.write("No files yet.")
    else:
        for fpath in sandbox_files:
            fname = os.path.basename(fpath)
            with open(fpath, "rb") as f:
                st.download_button(
                    label=f"β¬οΈ {fname}",
                    data=f,
                    file_name=fname,
                    mime="application/octet-stream"
                )
    st.divider()
    st.header("3. Voice Prompt")
    col_a, col_b = st.columns([3, 1])
    with col_a:
        st.markdown("**Record a voice message**")
    with col_b:
        # Bumping audio_key is intended to reset the recorder widget;
        # NOTE(review): st.audio_input below does not use audio_key as its
        # key — confirm whether the reset actually clears the widget.
        if st.button("β» Reset"):
            st.session_state.audio_key += 1
            st.session_state.voice_prompt_text = ""
            st.rerun()
    audio_value = st.audio_input("Record a voice message")
    if audio_value is not None:
        st.audio(audio_value)
        if st.button("π Transcribe Voice"):
            if not api_key:
                st.error("Please enter your API Key first.")
            else:
                with st.spinner("Transcribing..."):
                    try:
                        audio_bytes = audio_value.getbuffer()
                        # The transcript is staged in session state; the main
                        # pane offers a "send" button for it.
                        st.session_state.voice_prompt_text = transcribe_audio(
                            api_key, audio_bytes, filename=audio_value.name or "audio.wav"
                        )
                        st.success("Transcription ready.")
                    except Exception as e:
                        st.error(f"Transcription error: {e}")
# ==========================================
# π¬ Chat Interface
# ==========================================
# Main-pane page title.
st.title("β‘ Data Agent & Sandbox made for Transmed")
def process_prompt(prompt: str):
    """
    Handle one user prompt: record and render it, stream the agent's answer,
    then rerun the app so freshly created charts/files become visible.
    """
    add_message("user", prompt, "text")
    with st.chat_message("user"):
        st.markdown(prompt)
    if not st.session_state.agent:
        st.info("Please enter API key and upload files.")
        return
    succeeded = False
    with st.chat_message("assistant"):
        try:
            final_prompt = prompt
            if is_plot_request(prompt):
                # Nudge the model toward the charting tool and the 'fig' convention.
                final_prompt += "\n\nIMPORTANT: Call chart_generator. Assign the plot to 'fig'."
            response_stream = st.session_state.agent.stream_chat(final_prompt)
            full_response = st.write_stream(response_stream.response_gen)
            add_message("assistant", full_response, "text")
            succeeded = True
        except Exception as e:
            st.error(f"Error: {e}")
    # BUGFIX: st.rerun() works by raising a control-flow exception. Calling it
    # inside the try block above meant `except Exception` swallowed it and
    # displayed a bogus error instead of refreshing — so rerun outside the try.
    if succeeded:
        st.rerun()
# Render history: replay every stored message (text or chart) in order.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        kind = entry["type"]
        if kind == "text":
            st.markdown(entry["content"])
        elif kind == "chart":
            st.markdown("**Generated chart code:**")
            st.code(entry.get("code", ""), language="python")
            plotly_payload = entry.get("plotly_json")
            png_payload = entry.get("mpl_png")
            # Prefer the interactive Plotly payload; fall back to the PNG.
            if plotly_payload:
                try:
                    restored = go.Figure(json.loads(plotly_payload))
                    st.plotly_chart(restored, use_container_width=True)
                except Exception:
                    st.warning("Failed to render saved Plotly chart.")
            elif png_payload:
                st.image(png_payload, use_container_width=True)
            else:
                st.error("Chart data was not captured correctly.")
# Offer a "send" button for a staged voice transcription, if one exists.
if st.session_state.voice_prompt_text:
    with st.container():
        st.info(f"Voice prompt ready: {st.session_state.voice_prompt_text}")
        if st.button("π¨ Send voice prompt"):
            pending = st.session_state.voice_prompt_text
            # Clear the staged text BEFORE dispatch so it is not re-sent
            # on the rerun triggered inside process_prompt.
            st.session_state.voice_prompt_text = ""
            process_prompt(pending)
# Standard typed chat entry point.
prompt = st.chat_input("Ask: 'Plot sales, then create a Word report'")
if prompt:
    process_prompt(prompt)