SHAMIL SHAHBAZ AWAN commited on
Commit
05d8272
·
verified ·
1 Parent(s): df09e57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -2,9 +2,10 @@ import streamlit as st
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
 
5
  from io import StringIO
6
  import openai
7
- import os
8
 
9
  # Groq API key from secrets
10
  GROQ_API_KEY = os.getenv("HUGGINGFACE_KEY")
@@ -30,27 +31,18 @@ def generate_graph(data, query):
30
  """Generate a graph based on user query."""
31
  try:
32
  fig, ax = plt.subplots(figsize=(10, 6))
33
-
34
- if "correlation" in query.lower():
35
- # Correlation matrix
36
- sns.heatmap(data.corr(), annot=True, cmap="coolwarm", ax=ax)
37
- st.pyplot(fig)
38
- elif "histogram" in query.lower():
39
- # Histogram
40
- column = st.selectbox("Select a column for the histogram", data.columns)
41
- sns.histplot(data[column], kde=True, ax=ax)
42
- st.pyplot(fig)
43
- elif "bar" in query.lower() and "country" in query.lower():
44
- # Bar chart for countries and units sold
45
- if 'country' in data.columns and 'units sold' in data.columns:
46
- country_data = data[['country', 'units sold']].groupby('country').sum().reset_index()
47
- sns.barplot(x='country', y='units sold', data=country_data, ax=ax, color='skyblue')
48
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
49
  st.pyplot(fig)
50
  else:
51
- st.error("The dataset must contain 'country' and 'units sold' columns.")
52
  else:
53
- st.error("Unsupported graph type. Try asking for a correlation matrix, histogram, or bar chart.")
54
  except Exception as e:
55
  st.error(f"Error generating graph: {e}")
56
 
@@ -102,7 +94,7 @@ def main():
102
  st.write("Dataset preview:", data.head())
103
 
104
  # User input for graph generation
105
- query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and units sold')")
106
 
107
  if query:
108
  # Generate the graph based on the query
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
+ import os
6
  from io import StringIO
7
  import openai
8
+ from groq import Groq
9
 
10
  # Groq API key from secrets
11
  GROQ_API_KEY = os.getenv("HUGGINGFACE_KEY")
 
31
  """Generate a graph based on user query."""
32
  try:
33
  fig, ax = plt.subplots(figsize=(10, 6))
34
+
35
+ if "bar" in query.lower() and "gross sales" in query.lower():
36
+ # Bar chart for countries and gross sales
37
+ if 'country' in data.columns and 'gross_sales' in data.columns:
38
+ country_data = data[['country', 'gross_sales']].groupby('country').sum().reset_index()
39
+ sns.barplot(x='country', y='gross_sales', data=country_data, ax=ax, color='skyblue')
 
 
 
 
 
 
 
 
 
40
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
41
  st.pyplot(fig)
42
  else:
43
+ st.error("The dataset must contain 'country' and 'gross_sales' columns.")
44
  else:
45
+ st.error("Unsupported graph type. Try asking for a bar chart.")
46
  except Exception as e:
47
  st.error(f"Error generating graph: {e}")
48
 
 
94
  st.write("Dataset preview:", data.head())
95
 
96
  # User input for graph generation
97
+ query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
98
 
99
  if query:
100
  # Generate the graph based on the query