# Chart-bot / app.py
# Thunder-rk's picture
# Create app.py
# 9d4b537 verified
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
from dotenv import load_dotenv
# Let python-dotenv populate the process environment, then read the Gemini
# API key from it. API_KEY is None when GEMINI_API_KEY is not set.
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY")
# Function to get column suggestions from Gemini API
def _parse_gemini_reply(reply):
    """Extract the JSON object embedded in the model's text reply as a dict."""
    # The model often wraps its answer in a Markdown code fence; drop it.
    text = reply.replace("```json", "").replace("```", "").strip()
    # Tolerate stray prose around the object by slicing to the outer braces.
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in the model reply")
    # json.loads only parses data; unlike eval() it can never execute code
    # that a crafted query tricked the model into emitting.
    return json.loads(text[start:end + 1])


def ask_gemini_for_columns_and_graph(api_key, df, user_query):
    """
    Use Gemini API to determine the best columns and graph type based on the user's query.

    Args:
        api_key: Gemini API key. It is sent in the ``x-goog-api-key`` header
            rather than the URL so it cannot appear in error messages.
        df: DataFrame whose column names are offered to the model.
        user_query: Free-text description of the graph the user wants.

    Returns:
        A ``(x_column, y_column, graph_type)`` tuple read from the model's
        JSON reply (``y_column`` is ``None`` when the reply omits it), or
        ``(None, None, None)`` after reporting the problem with ``st.error``
        when the key is missing, the request fails, or the reply cannot be
        parsed.
    """
    if not api_key:
        st.error("Error interacting with the Gemini API: GEMINI_API_KEY is not set.")
        return None, None, None

    columns = df.columns.tolist()
    prompt = f"""
    You are analyzing a CSV file with the following columns: {columns}.
    Based on the user's query: "{user_query}",
    suggest two columns: one for the X-axis and one for the Y-axis, and the most suitable graph type (e.g., bar, scatter, line, histogram, pie).
    Respond in JSON format like this:
    {{
        "x_column": "ColumnX",
        "y_column": "ColumnY",
        "graph_type": "graphType"
    }}
    """
    payload = {
        "contents": [
            {
                "parts": [{"text": prompt}]
            }
        ]
    }
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }
    try:
        # A timeout keeps a stalled connection from hanging the app forever.
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()
        content = response.json()
        reply = content['candidates'][0]['content']['parts'][0]['text'].strip()
        result = _parse_gemini_reply(reply)
        # Pie charts and histograms need one column, so y_column may be absent.
        return result['x_column'], result.get('y_column'), result['graph_type']
    except Exception as e:
        # UI boundary: surface any network/parsing failure to the user.
        st.error(f"Error interacting with the Gemini API: {e}")
        return None, None, None
# Function to plot the graph
def plot_graph(df, x_column, y_column, graph_type):
    """
    Draw the requested chart for ``df`` and render it in the Streamlit page.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Column for the X-axis; for "pie" it is the column whose
            value counts are charted.
        y_column: Column for the Y-axis; for "histogram" it is the column
            that is binned (``x_column`` is used when it is empty). Ignored
            for "pie".
        graph_type: "bar", "scatter", "line", "histogram" or "pie", matched
            case-insensitively and ignoring surrounding whitespace.

    Returns:
        None. An unsupported type, an unknown column, or a plotting failure
        is reported with ``st.error`` rather than raised.
    """
    # The graph type comes from a language model, so accept "Bar", " pie ", etc.
    kind = graph_type.strip().lower() if isinstance(graph_type, str) else graph_type
    if kind not in ("bar", "scatter", "line", "histogram", "pie"):
        st.error(f"Unsupported graph type: {graph_type}")
        return

    # Work out which columns this chart actually reads.
    if kind == "histogram":
        hist_column = x_column if y_column in (None, "") else y_column
        required = [hist_column]
    elif kind == "pie":
        required = [x_column]
    else:
        required = [x_column, y_column]

    fig = None
    try:
        missing = [name for name in required if name not in df.columns]
        if missing:
            st.error(f"Error generating the plot: column(s) not found in the data: {missing}")
            return

        # Use an explicit figure instead of pyplot's global state so it can
        # be handed to Streamlit and then released.
        fig, ax = plt.subplots(figsize=(10, 6))
        if kind == "bar":
            ax.bar(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Bar Graph: {y_column} vs {x_column}")
            ax.tick_params(axis='x', labelrotation=45)
        elif kind == "scatter":
            ax.scatter(df[x_column], df[y_column], color='skyblue')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Scatter Plot: {y_column} vs {x_column}")
        elif kind == "line":
            ax.plot(df[x_column], df[y_column], color='skyblue', marker='o')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(f"Line Graph: {y_column} vs {x_column}")
        elif kind == "histogram":
            ax.hist(df[hist_column], bins=20, color='skyblue', edgecolor='black')
            ax.set_xlabel(hist_column)
            ax.set_ylabel("Frequency")
            ax.set_title(f"Histogram of {hist_column}")
        else:  # "pie"
            pie_data = df[x_column].value_counts()
            ax.pie(pie_data, labels=pie_data.index, autopct='%1.1f%%', startangle=90, colors=plt.cm.Paired.colors)
            ax.set_title(f"Pie Chart: Distribution of {x_column}")
        st.pyplot(fig)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the page.
        st.error(f"Error generating the plot: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each button click would leak one figure.
        if fig is not None:
            plt.close(fig)
# Streamlit Application
def main():
    """
    Render the Streamlit page: upload a CSV, describe a graph, and plot it.

    The uploaded file is shown as a table. When the user presses
    "Generate Graph", the description is sent to
    ``ask_gemini_for_columns_and_graph`` and the suggested chart is drawn with
    ``plot_graph``. Unreadable files and empty descriptions are reported on
    the page instead of raising.
    """
    st.title("Interactive Graph Generator")

    # File upload
    uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # An empty, malformed or non-UTF-8 upload should not crash the page.
        st.error(f"Could not read the CSV file: {e}")
        return
    st.dataframe(df)

    # User query input
    user_query = st.text_input("Describe the graph you'd like to generate:")
    if not st.button("Generate Graph"):
        return

    if not user_query.strip():
        # Don't spend an API call on an empty description.
        st.warning("Please describe the graph you'd like to generate.")
        return

    # Get graph suggestions from Gemini API
    x_column, y_column, graph_type = ask_gemini_for_columns_and_graph(API_KEY, df, user_query)
    if x_column and graph_type:
        plot_graph(df, x_column, y_column, graph_type)
    else:
        st.error("Could not determine columns or graph type.")
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()