# Hugging Face Space (status: Sleeping)
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
from dotenv import load_dotenv
# Load environment variables, then read the Gemini key from the environment.
# API_KEY is None when GEMINI_API_KEY is not defined.
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY")
# Function to get column suggestions from Gemini API
def ask_gemini_for_columns_and_graph(api_key, df, user_query):
    """
    Use Gemini API to determine the best columns and graph type based on the user's query.

    Args:
        api_key: Gemini API key. It is sent in the ``x-goog-api-key`` header
            rather than the URL so it cannot end up in error messages.
        df: DataFrame whose column names are offered to the model.
        user_query: Free-text description of the graph the user wants.

    Returns:
        A ``(x_column, y_column, graph_type)`` tuple taken from the model's
        JSON reply. ``y_column`` is ``None`` when the reply omits it. On any
        failure the problem is reported with ``st.error`` and
        ``(None, None, None)`` is returned.
    """
    if not api_key:
        st.error("Error interacting with the Gemini API: no API key configured (GEMINI_API_KEY).")
        return None, None, None

    columns = df.columns.tolist()
    prompt = (
        "\n"
        f"You are analyzing a CSV file with the following columns: {columns}.\n"
        f"Based on the user's query: \"{user_query}\",\n"
        "suggest two columns: one for the X-axis and one for the Y-axis, and the most suitable graph type (e.g., bar, scatter, line, histogram, pie).\n"
        "Respond in JSON format like this:\n"
        "{\n"
        '"x_column": "ColumnX",\n'
        '"y_column": "ColumnY",\n'
        '"graph_type": "graphType"\n'
        "}\n"
    )
    payload = {"contents": [{"parts": [{"text": prompt}]}]}
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
    try:
        # A timeout keeps a stalled connection from hanging the Streamlit run.
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()
        content = response.json()
        reply = content["candidates"][0]["content"]["parts"][0]["text"]

        # The model may wrap the JSON in a Markdown fence or add prose around
        # it; keep only the outermost {...} span.
        start = reply.find("{")
        end = reply.rfind("}")
        if start == -1 or end < start:
            raise ValueError("no JSON object found in the model reply")

        # json.loads, never eval: the reply is untrusted text, and eval would
        # both execute arbitrary code and reject true/false/null.
        result = json.loads(reply[start:end + 1])
        return result["x_column"], result.get("y_column"), result["graph_type"]
    except (
        requests.RequestException,
        KeyError,
        IndexError,
        TypeError,
        AttributeError,
        ValueError,
    ) as e:
        st.error(f"Error interacting with the Gemini API: {e}")
        return None, None, None
# Function to plot the graph
def plot_graph(df, x_column, y_column, graph_type):
    """
    Draw the requested graph from ``df`` and show it in the Streamlit page.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Column for the X-axis (bar, scatter, line) or the column
            whose value counts are shown (pie). Not used for histograms.
        y_column: Column for the Y-axis (bar, scatter, line) or the column
            whose distribution is shown (histogram). Not used for pie charts.
        graph_type: One of "bar", "scatter", "line", "histogram", "pie";
            matched case-insensitively and ignoring surrounding whitespace.

    Returns:
        None. Problems (unsupported type, unknown column, plotting failure)
        are reported with ``st.error`` instead of being raised.
    """
    # The graph type comes from free-form model output, so normalize it.
    kind = str(graph_type).strip().lower()

    # Columns each graph type actually reads.
    columns_needed = {
        "bar": (x_column, y_column),
        "scatter": (x_column, y_column),
        "line": (x_column, y_column),
        "histogram": (y_column,),
        "pie": (x_column,),
    }
    if kind not in columns_needed:
        st.error(f"Unsupported graph type: {graph_type}")
        return

    fig = None
    try:
        missing = [col for col in columns_needed[kind] if col not in df.columns]
        if missing:
            st.error(f"Error generating the plot: column(s) not found in the data: {missing}")
            return

        # Use an explicit Figure/Axes instead of pyplot's global state, so the
        # figure can be handed to Streamlit and closed afterwards.
        fig, ax = plt.subplots(figsize=(10, 6))

        if kind == "bar":
            ax.bar(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Bar Graph: {y_column} vs {x_column}")
            ax.tick_params(axis="x", labelrotation=45)
        elif kind == "scatter":
            ax.scatter(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Scatter Plot: {y_column} vs {x_column}")
        elif kind == "line":
            ax.plot(df[x_column], df[y_column], color='skyblue', marker='o')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Line Graph: {y_column} vs {x_column}")
        elif kind == "histogram":
            ax.hist(df[y_column], bins=20, color='skyblue', edgecolor='black')
            ax.set_xlabel(y_column)
            ax.set_ylabel("Frequency")
            ax.set_title(f"Histogram of {y_column}")
        else:  # kind == "pie"
            pie_data = df[x_column].value_counts()
            ax.pie(
                pie_data,
                labels=pie_data.index,
                autopct='%1.1f%%',
                startangle=90,
                colors=plt.cm.Paired.colors,
            )
            ax.set_title(f"Pie Chart: Distribution of {x_column}")

        st.pyplot(fig)
    except Exception as e:
        # Deliberately broad: plotting arbitrary user data can fail in many
        # ways, and this is the UI boundary where the failure is reported.
        st.error(f"Error generating the plot: {e}")
    finally:
        # Close the figure on every path so reruns do not accumulate figures.
        if fig is not None:
            plt.close(fig)
# Streamlit Application
def main():
    """
    Run the Streamlit page: upload a CSV, describe a graph, and render it.

    Flow: show the title and file uploader; once a CSV is uploaded, display
    it and ask for a description; when "Generate Graph" is pressed, ask
    Gemini for the columns and graph type and plot the result.
    """
    st.title("Interactive Graph Generator")

    # File upload
    uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
    if uploaded_file is None:
        return

    # A malformed, empty, or wrongly encoded upload should produce a message,
    # not a traceback.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        st.error(f"Could not read the CSV file: {e}")
        return
    st.dataframe(df)

    # User query input
    user_query = st.text_input("Describe the graph you'd like to generate:")
    if not st.button("Generate Graph"):
        return

    # Do not spend an API call on an empty description.
    if not user_query.strip():
        st.warning("Please describe the graph you'd like to generate.")
        return

    # Get graph suggestions from Gemini API
    x_column, y_column, graph_type = ask_gemini_for_columns_and_graph(API_KEY, df, user_query)
    if x_column and graph_type:
        plot_graph(df, x_column, y_column, graph_type)
    else:
        st.error("Could not determine columns or graph type.")
# Script entry point: start the app only when this file is executed directly.
if __name__ == "__main__":
    main()