SHAMIL SHAHBAZ AWAN commited on
Commit
b5e2763
·
verified ·
1 Parent(s): 05d8272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -4,11 +4,6 @@ import matplotlib.pyplot as plt
4
  import seaborn as sns
5
  import os
6
  from io import StringIO
7
- import openai
8
- from groq import Groq
9
-
10
- # Groq API key from secrets
11
- GROQ_API_KEY = os.getenv("HUGGINGFACE_KEY")
12
 
13
  # Function to load the uploaded file (CSV or Excel)
14
  def load_file(uploaded_file):
@@ -41,33 +36,47 @@ def generate_graph(data, query):
41
  st.pyplot(fig)
42
  else:
43
  st.error("The dataset must contain 'country' and 'gross_sales' columns.")
44
- else:
45
- st.error("Unsupported graph type. Try asking for a bar chart.")
46
- except Exception as e:
47
- st.error(f"Error generating graph: {e}")
48
-
49
- # Function to query the Groq model (Groq API)
50
- def query_groq_model(prompt):
51
- """Send a query to the Groq model and get a response."""
52
- try:
53
- # Initialize Groq client
54
- client = Groq(api_key=GROQ_API_KEY)
 
 
 
 
55
 
56
- # Query Groq model for response
57
- chat_completion = client.chat.completions.create(
58
- messages=[{"role": "user", "content": prompt}],
59
- model="llama-3.1-8b-instant", # Replace with your model
60
- )
 
 
 
 
61
 
62
- return chat_completion.choices[0].message.content
 
 
 
 
 
 
 
63
  except Exception as e:
64
- st.error(f"Error querying Groq model: {e}")
65
- return None
66
 
67
  # Streamlit App Interface
68
  def main():
69
- st.set_page_config(page_title="Data Augmentation and Visualization", page_icon="📊", layout="wide")
70
-
71
  # Set background image
72
  st.markdown(
73
  """
@@ -79,17 +88,17 @@ def main():
79
  </style>
80
  """, unsafe_allow_html=True
81
  )
82
-
83
- st.title("Data Augmentation and Visualization with Groq API")
84
  st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)
85
-
86
  # File upload section
87
  uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
88
-
89
  if uploaded_file is not None:
90
  # Load and display data
91
  data = load_file(uploaded_file)
92
-
93
  if data is not None:
94
  st.write("Dataset preview:", data.head())
95
 
@@ -100,14 +109,5 @@ def main():
100
  # Generate the graph based on the query
101
  generate_graph(data, query)
102
 
103
- # User input for Groq model query
104
- model_query = st.text_input("Ask Groq model a question:")
105
-
106
- if model_query:
107
- # Query the Groq model and display response
108
- response = query_groq_model(model_query)
109
- if response:
110
- st.write("Groq Model Response:", response)
111
-
112
  if __name__ == "__main__":
113
  main()
 
4
  import seaborn as sns
5
  import os
6
  from io import StringIO
 
 
 
 
 
7
 
8
  # Function to load the uploaded file (CSV or Excel)
9
  def load_file(uploaded_file):
 
36
  st.pyplot(fig)
37
  else:
38
  st.error("The dataset must contain 'country' and 'gross_sales' columns.")
39
+ elif "line" in query.lower() and "sales trend" in query.lower():
40
+ # Line chart for sales trend over time
41
+ if 'date' in data.columns and 'sales' in data.columns:
42
+ data['date'] = pd.to_datetime(data['date'])
43
+ sales_trend = data.groupby('date')['sales'].sum().reset_index()
44
+ sns.lineplot(x='date', y='sales', data=sales_trend, ax=ax)
45
+ ax.set_title("Sales Trend Over Time")
46
+ st.pyplot(fig)
47
+ else:
48
+ st.error("The dataset must contain 'date' and 'sales' columns.")
49
+ elif "scatter" in query.lower() and "relationship" in query.lower():
50
+ # Scatter plot for relationships
51
+ columns = query.lower().split("between")[-1].strip().split("and")
52
+ x_col = columns[0].strip()
53
+ y_col = columns[1].strip()
54
 
55
+ if x_col in data.columns and y_col in data.columns:
56
+ sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
57
+ ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
58
+ st.pyplot(fig)
59
+ else:
60
+ st.error(f"The dataset must contain '{x_col}' and '{y_col}' columns.")
61
+ elif "histogram" in query.lower():
62
+ # Histogram for a specified column
63
+ column = query.lower().split("for")[-1].strip()
64
 
65
+ if column in data.columns:
66
+ sns.histplot(data[column], bins=20, kde=True, ax=ax, color='green')
67
+ ax.set_title(f"Histogram of {column}")
68
+ st.pyplot(fig)
69
+ else:
70
+ st.error(f"The dataset must contain the column '{column}'.")
71
+ else:
72
+ st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
73
  except Exception as e:
74
+ st.error(f"Error generating graph: {e}")
 
75
 
76
  # Streamlit App Interface
77
  def main():
78
+ st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")
79
+
80
  # Set background image
81
  st.markdown(
82
  """
 
88
  </style>
89
  """, unsafe_allow_html=True
90
  )
91
+
92
+ st.title("Data Visualization App")
93
  st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)
94
+
95
  # File upload section
96
  uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
97
+
98
  if uploaded_file is not None:
99
  # Load and display data
100
  data = load_file(uploaded_file)
101
+
102
  if data is not None:
103
  st.write("Dataset preview:", data.head())
104
 
 
109
  # Generate the graph based on the query
110
  generate_graph(data, query)
111
 
 
 
 
 
 
 
 
 
 
112
  if __name__ == "__main__":
113
  main()