Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from io import StringIO | |
| import openai | |
| import os | |
# API key for the Groq chat endpoint. NOTE: it is read from the HUGGINGFACE_KEY
# environment variable (presumably configured as a Space secret), not from
# GROQ_API_KEY; the value is None when HUGGINGFACE_KEY is unset.
GROQ_API_KEY = os.environ.get("HUGGINGFACE_KEY")
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Read an uploaded CSV or ``.xlsx`` file into a DataFrame.

    The format is picked from the file name's extension first and from the
    reported MIME type only as a fallback. Browsers do not report CSV uploads
    consistently (some systems send ``application/vnd.ms-excel`` for ``.csv``),
    so trusting the MIME type alone rejects valid files.

    Args:
        uploaded_file: File-like upload object exposing ``name`` and ``type``
            attributes (as returned by ``st.file_uploader``).

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the file type is not
        supported or parsing fails; in both cases an error is shown with
        ``st.error``.
    """
    name = str(getattr(uploaded_file, "name", "") or "").lower()
    mime = getattr(uploaded_file, "type", None)

    if name.endswith(".csv"):
        reader = pd.read_csv
    elif name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime == "text/csv":
        reader = pd.read_csv
    elif mime == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        # Defensive rewind: if the buffer was already read, parsing would
        # otherwise start at the end and see an empty file.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        return reader(uploaded_file)
    except Exception as e:
        st.error(f"Error loading file: {e}")
        return None
def _find_column(frame, wanted):
    """Return the column of ``frame`` named ``wanted`` (given in lower case), or ``None``.

    An exact match wins; otherwise the first string column whose stripped,
    lower-cased name equals ``wanted`` is returned (so ``"Units Sold"`` matches
    ``"units sold"``).
    """
    if wanted in frame.columns:
        return wanted
    for column in frame.columns:
        if isinstance(column, str) and column.strip().lower() == wanted:
            return column
    return None


# Function to generate graph based on user query
def generate_graph(data, query):
    """Render a chart selected by keywords in ``query`` onto the Streamlit page.

    Keywords are matched case-insensitively and checked in this order:

    * ``"correlation"``: heatmap of the correlation matrix of the numeric columns.
    * ``"histogram"``: histogram with KDE of a column the user picks in a selectbox.
    * ``"bar"`` together with ``"country"``: bar chart of total ``units sold`` per
      ``country``. Column names are matched exactly first, then ignoring case and
      surrounding whitespace.

    Args:
        data: The uploaded ``pandas.DataFrame``.
        query: Free-text chart request typed by the user.

    Returns:
        None. The chart or an error message is written directly to the page. Any
        exception raised while plotting is reported with ``st.error``.
    """
    fig = None
    try:
        request = query.lower()
        fig, ax = plt.subplots(figsize=(10, 6))
        if "correlation" in request:
            # DataFrame.corr() raises on non-numeric columns in pandas >= 2.0,
            # so only the numeric columns are correlated.
            numeric = data.select_dtypes(include="number")
            if numeric.columns.empty:
                st.error("A correlation matrix needs at least one numeric column.")
                return
            sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
            st.pyplot(fig)
        elif "histogram" in request:
            column = st.selectbox("Select a column for the histogram", data.columns)
            sns.histplot(data[column], kde=True, ax=ax)
            st.pyplot(fig)
        elif "bar" in request and "country" in request:
            # Bar chart of total units sold per country.
            country_col = _find_column(data, "country")
            units_col = _find_column(data, "units sold")
            if country_col is None or units_col is None:
                st.error("The dataset must contain 'country' and 'units sold' columns.")
                return
            country_data = data.groupby(country_col, as_index=False)[units_col].sum()
            sns.barplot(x=country_col, y=units_col, data=country_data, ax=ax, color="skyblue")
            # Rotate the existing tick labels in place. Re-setting them via
            # set_xticklabels() swaps in fixed text and can emit matplotlib's
            # FixedFormatter warning.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            st.pyplot(fig)
        else:
            st.error("Unsupported graph type. Try asking for a correlation matrix, histogram, or bar chart.")
    except Exception as e:
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each Streamlit rerun would leak one figure, on every path.
        if fig is not None:
            plt.close(fig)
# Function to query the Groq model (Groq API)
def query_groq_model(prompt, model="llama-3.1-8b-instant"):
    """Send ``prompt`` to a Groq-hosted chat model and return the reply text.

    Args:
        prompt: User message, sent as a single-turn chat.
        model: Groq model id. Defaults to ``"llama-3.1-8b-instant"``.

    Returns:
        The content of the first completion choice, or ``None`` on failure
        (missing API key or any error from the request, reported via
        ``st.error``).
    """
    # Fall back to the conventional GROQ_API_KEY variable (as the Groq SDK
    # itself does) when the module-level key read from HUGGINGFACE_KEY is unset.
    api_key = GROQ_API_KEY or os.getenv("GROQ_API_KEY")
    if not api_key:
        st.error("Error querying Groq model: no API key configured (set HUGGINGFACE_KEY or GROQ_API_KEY).")
        return None
    try:
        # The `groq` package was never imported here, so `Groq(...)` always
        # raised NameError. Groq serves an OpenAI-compatible API, so the already
        # imported `openai` client (v1+ interface) is pointed at it instead.
        client = openai.OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        st.error(f"Error querying Groq model: {e}")
        return None
# Streamlit App Interface
def main():
    """Build the Streamlit page.

    Configures the page and its background, then shows the upload widget. Once a
    file has been uploaded and parsed successfully, it shows a preview, a
    chart-request box handled by ``generate_graph``, and a question box whose
    text is forwarded to ``query_groq_model``.
    """
    st.set_page_config(page_title="Data Augmentation and Visualization", page_icon="📊", layout="wide")

    # Page-wide background image, injected as raw CSS.
    background_css = """
<style>
.stApp {
background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
background-size: cover;
}
</style>
"""
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Augmentation and Visualization with Groq API")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    upload = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if upload is None:
        return  # nothing uploaded yet

    frame = load_file(upload)
    if frame is None:
        return  # load_file already reported the problem via st.error

    st.write("Dataset preview:", frame.head())

    chart_request = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and units sold')")
    if chart_request:
        generate_graph(frame, chart_request)

    question = st.text_input("Ask Groq model a question:")
    if not question:
        return
    answer = query_groq_model(question)
    if answer:
        st.write("Groq Model Response:", answer)
# Entry point: build the app when this file is executed as the main module.
if __name__ == "__main__":
    main()