Spaces:
Runtime error
Runtime error
| import os | |
| import streamlit as st | |
| import pandas as pd | |
| import openai | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| from dotenv import load_dotenv | |
| import anthropic | |
| import ast | |
# Load environment variables.
# load_dotenv() reads a local .env file (if one exists) into os.environ, so
# OPENAI_API_KEY / ANTHROPIC_API_KEY become visible to the SDK clients.
#
# The previous `os.environ[KEY] = os.getenv(KEY)` lines were removed: when the
# key is present they are a no-op, and when it is absent os.getenv() returns
# None, which os.environ rejects with "TypeError: str expected, not NoneType",
# crashing the app at import time before any UI is drawn.
load_dotenv()
# UI styling: custom CSS for Streamlit buttons, text inputs, the file
# uploader, and a `.response-box` helper class. Kept in a named constant so
# the markup is separate from the call that injects it.
_CUSTOM_CSS = """
<style>
.stButton button {
background-color: #1F6FEB;
color: white;
border-radius: 8px;
border: none;
padding: 10px 20px;
font-weight: bold;
}
.stButton button:hover {
background-color: #1A4FC5;
}
.stTextInput > div > input {
border: 1px solid #30363D;
background-color: #161B22;
color: #C9D1D9;
border-radius: 6px;
padding: 10px;
}
.stFileUploader > div {
border: 2px dashed #30363D;
background-color: #161B22;
color: #C9D1D9;
border-radius: 6px;
padding: 10px;
}
.response-box {
background-color: #161B22;
padding: 10px;
border-radius: 6px;
margin-bottom: 10px;
color: #FFFFFF;
}
</style>
"""

# unsafe_allow_html is required for the raw <style> tag to be rendered.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# NOTE(review): the trailing emoji was mis-encoded as "π" in the source text;
# the chart emoji below is a best-guess restoration - confirm the intended one.
st.title("Excel Q&A Chatbot 📊")

# Model selection: the chosen label decides which backend is initialised and
# which ask_* helper is used to interpret the user's question.
model_choice = st.selectbox(
    "Select LLM Model",
    ["OpenAI GPT-3.5", "Claude 3 Haiku", "Mistral-7B"],
)
# Load the backend for the selected model. Exactly one of ask_mistral /
# ask_claude / ask_gpt is defined per script run; each takes a prompt string
# and returns the model's reply as a string.
if model_choice == "Mistral-7B":

    @st.cache_resource
    def _load_mistral(model_name):
        """Load the tokenizer and model once and reuse them across reruns.

        Streamlit re-executes the script on every interaction; without
        caching, the 7B checkpoint would be reloaded each time.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 kernels are incomplete on CPU in many torch builds, so fall
        # back to fp32 when no GPU is present.
        dtype = torch.float16 if device == "cuda" else torch.float32
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype
        )
        # The original left the model on CPU while sending inputs to "cuda",
        # which fails with a device mismatch (or outright without a GPU).
        loaded_model.to(device)
        loaded_model.eval()
        return loaded_tokenizer, loaded_model

    # NOTE(review): the unversioned id "mistralai/Mistral-7B-Instruct" is not a
    # published repository; v0.2 is assumed here - confirm the intended
    # checkpoint (it may also require Hugging Face authentication).
    tokenizer, model = _load_mistral("mistralai/Mistral-7B-Instruct-v0.2")

    def ask_mistral(query):
        """Generate a completion for *query* and return only the new text."""
        inputs = tokenizer(query, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                # Without an explicit budget, generate() uses a very small
                # default length and truncates the answer.
                max_new_tokens=256,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Drop the echoed prompt tokens and special tokens so the caller gets
        # just the generated text.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

elif model_choice == "Claude 3 Haiku":
    _anthropic_key = os.getenv("ANTHROPIC_API_KEY")
    if not _anthropic_key:
        # Show a readable message instead of a KeyError traceback.
        st.error("ANTHROPIC_API_KEY is not set; cannot use Claude 3 Haiku.")
        st.stop()
    client = anthropic.Anthropic(api_key=_anthropic_key)

    def ask_claude(query):
        """Send *query* to Claude 3 Haiku and return the reply text."""
        response = client.messages.create(
            # The Messages API expects a dated model id.
            model="claude-3-haiku-20240307",
            max_tokens=512,
            messages=[{"role": "user", "content": query}],
        )
        # Content blocks are SDK objects, not dicts: read `.text`, and join
        # all text blocks in case the reply is split across several.
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

else:
    if not os.getenv("OPENAI_API_KEY"):
        # openai.OpenAI() raises when no key is configured; explain why.
        st.error("OPENAI_API_KEY is not set; cannot use OpenAI GPT-3.5.")
        st.stop()
    client = openai.OpenAI()

    def ask_gpt(query):
        """Send *query* to gpt-3.5-turbo and return the reply text."""
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": query}],
        )
        # message.content may be None; callers expect a string.
        return response.choices[0].message.content or ""
# File upload, data preview, and natural-language querying of the uploaded
# table. Helpers are defined first; the script flow follows them.


def _build_prompt(question, dtypes):
    """Return the LLM prompt for *question*, including the table schema.

    *dtypes* maps column name -> dtype name. The schema is included so the
    model can refer to real column names instead of guessing them.
    """
    schema = ", ".join(f"{name!r} ({dtype})" for name, dtype in dtypes.items())
    return (
        f"Convert this question into a Pandas operation: {question}\n"
        f"The DataFrame is named df and has these columns: {schema}.\n"
        "Reply with a single Python expression only, with no explanation."
    )


def _ask_selected_model(prompt):
    """Route *prompt* to the ask_* helper matching the selected model."""
    if model_choice == "Mistral-7B":
        return ask_mistral(prompt)
    if model_choice == "Claude 3 Haiku":
        return ask_claude(prompt)
    return ask_gpt(prompt)


def _extract_expression(reply):
    """Strip Markdown code fences and stray backticks from a model reply."""
    text = reply.strip()
    fence_start = text.find("```")
    if fence_start != -1:
        fence_end = text.find("```", fence_start + 3)
        text = text[fence_start + 3 : fence_end if fence_end != -1 else None]
        # A fence may carry a language tag on its first line; removing only
        # the backticks would leave that tag in front of the expression.
        first_line, newline, rest = text.partition("\n")
        if newline and first_line.strip().lower() in {"python", "python3", "py"}:
            text = rest
    return text.replace("`", "").strip()


def _referenced_columns(tree):
    """Return the string keys used in df["..."] / df[["...", ...]] subscripts."""
    found = set()
    for node in ast.walk(tree):
        is_df_subscript = (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id == "df"
        )
        if not is_df_subscript:
            continue
        key = node.slice
        candidates = key.elts if isinstance(key, ast.List) else [key]
        for item in candidates:
            if isinstance(item, ast.Constant) and isinstance(item.value, str):
                found.add(item.value)
    return found


def _run_and_show(parsed_query, df):
    """Validate *parsed_query*, evaluate it against *df*, and render the result."""
    try:
        tree = ast.parse(parsed_query, mode="eval")
    except (SyntaxError, ValueError) as e:
        st.error(f"Syntax Error in parsed query: {str(e)}")
        return

    # The original check iterated over the frame's own columns, so it could
    # never fail. Inspect the columns the expression actually references.
    missing = sorted(_referenced_columns(tree) - set(df.columns))
    if missing:
        for col in missing:
            st.error(f"Error: Column '{col}' not found in the uploaded file.")
        return

    try:
        # SECURITY: this evaluates model-generated code in the app process.
        # The globals below are not a sandbox; do not expose this app to
        # untrusted users without isolating execution.
        result = eval(parsed_query, {"df": df, "pd": pd})
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return

    st.write("### Result:")
    # Let Streamlit render tabular results natively; stringify the rest.
    if isinstance(result, (pd.DataFrame, pd.Series)):
        st.write(result)
    else:
        st.write(str(result))

    # If the result is a non-empty numeric Series, show a bar chart.
    plottable = (
        isinstance(result, pd.Series)
        and not result.empty
        and pd.api.types.is_numeric_dtype(result)
        and not pd.api.types.is_bool_dtype(result)
    )
    if plottable:
        fig, ax = plt.subplots()
        try:
            result.plot(kind="bar", ax=ax)
            st.pyplot(fig)
        except Exception as e:
            st.error(f"Error executing query: {str(e)}")
        finally:
            # Figures persist across Streamlit reruns unless closed.
            plt.close(fig)


uploaded_file = st.file_uploader("Upload an Excel file", type=["csv", "xlsx"])

if uploaded_file is not None:
    file_extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    try:
        if file_extension == "csv":
            df = pd.read_csv(uploaded_file)
        else:
            df = pd.read_excel(uploaded_file, engine="openpyxl")
    except Exception as e:
        # Parser errors differ by format; report them and end this run.
        st.error(f"Could not read the uploaded file: {e}")
        st.stop()

    st.write("### Preview of Data:")
    st.write(df.head())

    # Extract metadata
    column_names = df.columns.tolist()
    data_types = {name: dtype.name for name, dtype in df.dtypes.items()}
    missing_counts = df.isnull().sum().tolist()

    # Display metadata
    st.write("### Column Details:")
    st.write(
        pd.DataFrame(
            {
                "Column": column_names,
                "Type": [dtype.name for dtype in df.dtypes],
                "Missing Values": missing_counts,
            }
        )
    )

    # Memory for context retention: submitted questions are stored in the
    # session. Nothing visible in this file reads the history back yet.
    if "query_history" not in st.session_state:
        st.session_state.query_history = []

    # User query
    query = st.text_input("Ask a question about this data:")
    if st.button("Submit Query"):
        if not query.strip():
            st.warning("Please enter a question first.")
        else:
            st.session_state.query_history.append(query)
            try:
                reply = _ask_selected_model(_build_prompt(query, data_types))
            except Exception as e:
                st.error(f"Error contacting the model: {e}")
            else:
                parsed_query = _extract_expression(reply)
                st.write(f"Parsed Query: `{parsed_query}`")
                _run_and_show(parsed_query, df)