import os
from io import StringIO

import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns
import streamlit as st

# Groq API key from secrets. Prefer a dedicated GROQ_API_KEY variable, but keep
# reading the legacy HUGGINGFACE_KEY secret so existing deployments still work.
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HUGGINGFACE_KEY")

# Groq exposes an OpenAI-compatible REST endpoint, so the already-imported
# `openai` client (openai>=1.0) can talk to it without an extra dependency.
GROQ_BASE_URL = "https://api.groq.com/openai/v1"

_CSV_MIME = "text/csv"
_XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load an uploaded CSV or Excel (.xlsx) file into a DataFrame.

    The reader is chosen from the file extension first, falling back to the
    browser-reported MIME type. Browsers do not report MIME types consistently
    (a ``.csv`` is often sent as ``application/vnd.ms-excel`` on Windows), so
    relying on the MIME type alone rejects valid uploads.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        The loaded DataFrame, or ``None`` when the type is unsupported or
        parsing fails (an error is shown in the UI in both cases).
    """
    name = (getattr(uploaded_file, "name", "") or "").lower()
    mime = getattr(uploaded_file, "type", "") or ""

    if name.endswith(".csv"):
        reader = pd.read_csv
    elif name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime == _CSV_MIME:
        reader = pd.read_csv
    elif mime == _XLSX_MIME:
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: report any parse failure to the user
        st.error(f"Error loading file: {e}")
        return None


def _find_column(data, wanted):
    """Return the column of *data* matching *wanted*, ignoring case/whitespace, or None."""
    # An exact match wins so datasets that already use the canonical
    # lower-case names behave exactly as before.
    if wanted in data.columns:
        return wanted
    target = wanted.strip().lower()
    for column in data.columns:
        if str(column).strip().lower() == target:
            return column
    return None


# Function to generate graph based on user query
def generate_graph(data, query):
    """Render a chart chosen by keywords in a free-text query.

    Supported requests (keywords matched case-insensitively):
      * "correlation": heatmap of the correlation matrix of numeric columns.
      * "histogram": histogram (with KDE) of a column the user selects.
      * "bar" plus "country"/"countries": total units sold per country,
        using the columns "country" and "units sold" (matched ignoring case).

    Problems are reported in the UI via ``st.error`` rather than raised.
    """
    q = query.lower()
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if "correlation" in q:
            # DataFrame.corr() raises on non-numeric columns in pandas >= 2.0.
            numeric = data.select_dtypes(include="number")
            if numeric.empty or numeric.shape[1] == 0:
                st.error("The dataset has no numeric columns to correlate.")
            else:
                sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
                st.pyplot(fig)
        elif "histogram" in q:
            column = st.selectbox("Select a column for the histogram", data.columns)
            sns.histplot(data[column], kde=True, ax=ax)
            st.pyplot(fig)
        elif "bar" in q and "countr" in q:
            # "countr" matches both "country" and "countries", so the example
            # query shown in the UI is actually recognised.
            country_col = _find_column(data, "country")
            units_col = _find_column(data, "units sold")
            if country_col is None or units_col is None:
                st.error("The dataset must contain 'country' and 'units sold' columns.")
            else:
                country_data = data.groupby(country_col, as_index=False)[units_col].sum()
                sns.barplot(x=country_col, y=units_col, data=country_data, ax=ax, color="skyblue")
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
                st.pyplot(fig)
        else:
            st.error("Unsupported graph type. Try asking for a correlation matrix, histogram, or bar chart.")
    except Exception as e:  # UI boundary: surface plotting errors to the user
        st.error(f"Error generating graph: {e}")
    finally:
        # st.pyplot has already rendered the figure; close it so figures do
        # not accumulate in pyplot's global registry across Streamlit reruns.
        plt.close(fig)


# Function to query the Groq model (Groq API)
def query_groq_model(prompt):
    """Send *prompt* to a Groq-hosted chat model and return the reply text.

    Uses Groq's OpenAI-compatible endpoint through the ``openai`` client.

    Returns:
        The model's reply, or ``None`` when the API key is missing or the
        request fails (an error is shown in the UI in both cases).
    """
    if not GROQ_API_KEY:
        st.error("Groq API key is not configured. Set the GROQ_API_KEY (or HUGGINGFACE_KEY) environment variable.")
        return None
    try:
        client = openai.OpenAI(api_key=GROQ_API_KEY, base_url=GROQ_BASE_URL)
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama-3.1-8b-instant",  # Replace with your model
        )
        return chat_completion.choices[0].message.content
    except Exception as e:  # UI boundary: network/auth/model errors
        st.error(f"Error querying Groq model: {e}")
        return None


# Streamlit App Interface
def main():
    """Build the page: file upload and preview, query-driven chart, and a Groq Q&A box."""
    st.set_page_config(page_title="Data Augmentation and Visualization", page_icon="📊", layout="wide")

    # Placeholder for custom page CSS (e.g. a background image); currently empty.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )

    st.title("Data Augmentation and Visualization with Groq API")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # File upload section
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is not None:
        data = load_file(uploaded_file)
        if data is not None:
            st.write("Dataset preview:", data.head())

            # User input for graph generation
            query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and units sold')")
            if query:
                generate_graph(data, query)

    # User input for Groq model query (independent of the uploaded data)
    model_query = st.text_input("Ask Groq model a question:")
    if model_query:
        response = query_groq_model(model_query)
        if response:
            st.write("Groq Model Response:", response)


if __name__ == "__main__":
    main()