Thunder-rk committed on
Commit
9d4b537
·
verified ·
1 Parent(s): f83b90c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import requests
5
+ from dotenv import load_dotenv
6
+ import os
7
+
8
# Populate the process environment from a .env file (if one exists), then
# look up the Gemini API key; API_KEY is None when the variable is not set.
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY")
11
+
12
+ # Function to get column suggestions from Gemini API
13
def _parse_gemini_reply(reply):
    """Extract ``(x_column, y_column, graph_type)`` from the model's text reply."""
    # Function-scope import keeps this block self-contained.
    import json

    # The model may wrap the object in Markdown code fences or surrounding
    # prose, so parse only the outermost {...} span.
    start = reply.find("{")
    end = reply.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in the Gemini reply")

    # NOTE(security): the reply is untrusted model output. It must be parsed
    # as data with json.loads and never passed to eval().
    result = json.loads(reply[start:end + 1])
    return result["x_column"], result["y_column"], result["graph_type"]


def ask_gemini_for_columns_and_graph(api_key, df, user_query):
    """
    Use Gemini API to determine the best columns and graph type based on the user's query.

    Args:
        api_key: Gemini API key. It is sent in the ``x-goog-api-key`` header
            rather than the URL so it cannot leak through error messages.
        df: DataFrame whose column names are offered to the model.
        user_query: Free-text description of the graph the user wants.

    Returns:
        A ``(x_column, y_column, graph_type)`` tuple taken from the model's
        JSON reply, or ``(None, None, None)`` after reporting the problem via
        ``st.error`` when the key is missing, the request fails, or the reply
        cannot be parsed.
    """
    if not api_key:
        st.error("Error interacting with the Gemini API: GEMINI_API_KEY is not set.")
        return None, None, None

    columns = df.columns.tolist()

    prompt = f"""
    You are analyzing a CSV file with the following columns: {columns}.
    Based on the user's query: "{user_query}",
    suggest two columns: one for the X-axis and one for the Y-axis, and the most suitable graph type (e.g., bar, scatter, line, histogram, pie).
    Respond in JSON format like this:
    {{
    "x_column": "ColumnX",
    "y_column": "ColumnY",
    "graph_type": "graphType"
    }}
    """

    payload = {
        "contents": [
            {
                "parts": [{"text": prompt}]
            }
        ]
    }

    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }

    try:
        # A timeout keeps the Streamlit script from hanging on a stalled connection.
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()
        content = response.json()
        reply = content['candidates'][0]['content']['parts'][0]['text'].strip()
        return _parse_gemini_reply(reply)
    except Exception as e:
        # UI boundary: any network, schema, or parsing failure is reported to
        # the user instead of crashing the app.
        st.error(f"Error interacting with the Gemini API: {e}")
        return None, None, None
52
+
53
+ # Function to plot the graph
54
def plot_graph(df, x_column, y_column, graph_type):
    """
    Render the requested chart for ``df`` in the Streamlit app.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Column for the X-axis (bar, scatter, line) or the column
            whose value counts are shown (pie). Not used for histograms.
        y_column: Column for the Y-axis (bar, scatter, line) or the column
            whose distribution is shown (histogram). Not used for pie charts.
        graph_type: One of "bar", "scatter", "line", "histogram", "pie";
            matched case-insensitively, ignoring surrounding whitespace.

    Returns:
        None. The chart is shown with ``st.pyplot``; an unsupported graph
        type, a missing column, or a plotting failure is reported with
        ``st.error``.
    """
    kind = graph_type.strip().lower() if isinstance(graph_type, str) else None

    # Columns each graph type actually reads from the DataFrame.
    required_columns = {
        "bar": (x_column, y_column),
        "scatter": (x_column, y_column),
        "line": (x_column, y_column),
        "histogram": (y_column,),
        "pie": (x_column,),
    }
    needed = required_columns.get(kind)
    if needed is None:
        st.error(f"Unsupported graph type: {graph_type}")
        return

    fig = None
    try:
        missing = [col for col in needed if col not in df.columns]
        if missing:
            names = ", ".join(map(str, missing))
            st.error(f"Error generating the plot: column(s) not found: {names}")
            return

        # Use an explicit figure so it can be passed to Streamlit and closed
        # afterwards, instead of relying on pyplot's global state.
        fig, ax = plt.subplots(figsize=(10, 6))

        if kind == "bar":
            ax.bar(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Bar Graph: {y_column} vs {x_column}")
            ax.tick_params(axis="x", labelrotation=45)
        elif kind == "scatter":
            ax.scatter(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Scatter Plot: {y_column} vs {x_column}")
        elif kind == "line":
            ax.plot(df[x_column], df[y_column], color='skyblue', marker='o')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Line Graph: {y_column} vs {x_column}")
        elif kind == "histogram":
            ax.hist(df[y_column], bins=20, color='skyblue', edgecolor='black')
            ax.set_xlabel(y_column)
            ax.set_ylabel("Frequency")
            ax.set_title(f"Histogram of {y_column}")
        else:  # "pie"
            pie_data = df[x_column].value_counts()
            ax.pie(pie_data, labels=pie_data.index, autopct='%1.1f%%', startangle=90, colors=plt.cm.Paired.colors)
            ax.set_title(f"Pie Chart: Distribution of {x_column}")

        st.pyplot(fig)
    except Exception as e:
        # UI boundary: report plotting problems (e.g. non-numeric data)
        # instead of crashing the app.
        st.error(f"Error generating the plot: {e}")
    finally:
        # Figures created through pyplot stay registered until closed; close
        # this one so reruns of the script do not accumulate figures.
        if fig is not None:
            plt.close(fig)
88
+
89
+ # Streamlit Application
90
def main():
    """
    Run the Streamlit UI: upload a CSV, describe a graph, and render it.

    Flow: show the uploaded data, take a free-text query, ask Gemini for the
    columns and graph type via ``ask_gemini_for_columns_and_graph``, then draw
    the result with ``plot_graph``. Problems are reported in the page rather
    than raised.
    """
    st.title("Interactive Graph Generator")

    # File upload
    uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # An empty, malformed, or wrongly encoded upload should produce a
        # readable message, not an unhandled exception.
        st.error(f"Could not read the CSV file: {e}")
        return
    st.dataframe(df)

    # User query input
    user_query = st.text_input("Describe the graph you'd like to generate:")

    if not st.button("Generate Graph"):
        return

    if not user_query.strip():
        # Avoid spending an API call on a query the model cannot act on.
        st.warning("Please describe the graph you'd like to generate.")
        return

    # Get graph suggestions from Gemini API
    x_column, y_column, graph_type = ask_gemini_for_columns_and_graph(API_KEY, df, user_query)

    # Histograms use only the Y column and pie charts only the X column, so
    # one suggested column plus a graph type is enough to attempt a plot.
    if graph_type and (x_column or y_column):
        plot_graph(df, x_column, y_column, graph_type)
    else:
        st.error("Could not determine columns or graph type.")
110
+
111
# Start the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()