rairo commited on
Commit
3b709a4
·
verified ·
1 Parent(s): 0d843e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from dotenv import load_dotenv
3
+ import os
4
+ import pandas as pd
5
+ from pandasai import SmartDataframe, SmartDatalake
6
+ from pandasai.responses.response_parser import ResponseParser
7
+ from pandasai.llm import GoogleGemini
8
+ import plotly.graph_objects as go
9
+ from PIL import Image
10
+ import io
11
+ import base64
12
+ import requests
13
+
14
+ # API Endpoint
15
+ API_URL = "https://irisplus.elixir.co.zw/public/api/profile/reporting/stock-card/genericReports"
16
+ PAYLOAD = {
17
+ "stock_card_report_id": "d2f1a0e1-7be1-472c-9610-94287154e544"
18
+ }
19
+
20
+ def fetch_data():
21
+ """Fetch stock card report data from API, return cleaned DataFrame"""
22
+ response = requests.post(API_URL, data=PAYLOAD)
23
+ if response.status_code == 200:
24
+ try:
25
+ data = response.json()
26
+ if isinstance(data, dict) and 'actual_report' in data and isinstance(data['actual_report'], list):
27
+ df = pd.DataFrame(data['actual_report']) # Convert list to DataFrame
28
+
29
+ # Remove columns where all values are None
30
+ df.dropna(axis=1, how='all', inplace=True)
31
+
32
+ return df
33
+ else:
34
+ st.error("Unexpected response format from API.")
35
+ return None
36
+ except ValueError:
37
+ st.error("Error: Response is not valid JSON.")
38
+ return None
39
+ else:
40
+ st.error(f"Error fetching data: {response.status_code} - {response.text}")
41
+ return None
42
+
43
+
44
+ class StreamLitResponse(ResponseParser):
45
+ def __init__(self, context):
46
+ super().__init__(context)
47
+
48
+ def format_dataframe(self, result):
49
+ """Enhanced DataFrame rendering with type identifier"""
50
+ return {
51
+ 'type': 'dataframe',
52
+ 'value': result['value']
53
+ }
54
+
55
+ def format_plot(self, result):
56
+ """Enhanced plot rendering with type identifier"""
57
+ try:
58
+ image = result['value']
59
+
60
+ # Convert image to base64 for consistent storage
61
+ if isinstance(image, Image.Image):
62
+ buffered = io.BytesIO()
63
+ image.save(buffered, format="PNG")
64
+ base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
65
+ elif isinstance(image, bytes):
66
+ base64_image = base64.b64encode(image).decode('utf-8')
67
+ elif isinstance(image, str) and os.path.exists(image):
68
+ with open(image, "rb") as f:
69
+ base64_image = base64.b64encode(f.read()).decode('utf-8')
70
+ else:
71
+ return {'type': 'text', 'value': "Unsupported image format"}
72
+
73
+ return {
74
+ 'type': 'plot',
75
+ 'value': base64_image
76
+ }
77
+ except Exception as e:
78
+ return {'type': 'text', 'value': f"Error processing plot: {e}"}
79
+
80
+ def format_other(self, result):
81
+ """Handle other types of responses"""
82
+ return {
83
+ 'type': 'text',
84
+ 'value': str(result['value'])
85
+ }
86
+
87
+ # Load environment variables
88
+ load_dotenv()
89
+ GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
90
+
91
+ if not GOOGLE_API_KEY:
92
+ st.error("GOOGLE_API_KEY environment variable not set.")
93
+ st.stop()
94
+
95
+ def generateResponse(prompt, dfs):
96
+ """Generate response using PandasAI"""
97
+ llm = GoogleGemini(api_key=GOOGLE_API_KEY)
98
+ pandas_agent = SmartDatalake(dfs, config={
99
+ "llm": llm,
100
+ "response_parser": StreamLitResponse
101
+ })
102
+ return pandas_agent.chat(prompt)
103
+
104
+ def render_chat_message(message):
105
+ """Render different types of chat messages"""
106
+ if "dataframe" in message:
107
+ st.dataframe(message["dataframe"])
108
+ elif "plot" in message:
109
+ try:
110
+ # Handle base64 encoded images
111
+ plot_data = message["plot"]
112
+ if isinstance(plot_data, str):
113
+ st.image(f"data:image/png;base64,{plot_data}")
114
+ elif isinstance(plot_data, Image.Image):
115
+ st.image(plot_data)
116
+ elif isinstance(plot_data, go.Figure):
117
+ st.plotly_chart(plot_data)
118
+ elif isinstance(plot_data, bytes):
119
+ image = Image.open(io.BytesIO(plot_data))
120
+ st.image(image)
121
+ else:
122
+ st.write("Unsupported plot format")
123
+ except Exception as e:
124
+ st.error(f"Error rendering plot: {e}")
125
+
126
+ # Always render text content
127
+ if "content" in message:
128
+ st.markdown(message["content"])
129
+
130
+
131
+ def handle_userinput(question, dfs):
132
+ """Enhanced input handling with robust content processing"""
133
+ try:
134
+ # Ensure data is loaded
135
+ if dfs:
136
+ # Append user input to chat history
137
+ st.session_state.chat_history.append({
138
+ "role": "user",
139
+ "content": question
140
+ })
141
+
142
+ # Generate response with PandasAI
143
+ result = generateResponse(question, dfs)
144
+
145
+ # Handle different response types
146
+ if isinstance(result, dict):
147
+ response_type = result.get('type', 'text')
148
+ response_value = result.get('value')
149
+
150
+ if response_type == 'dataframe':
151
+ st.session_state.chat_history.append({
152
+ "role": "assistant",
153
+ "content": "Here's the table:",
154
+ "dataframe": response_value
155
+ })
156
+ elif response_type == 'plot':
157
+ st.session_state.chat_history.append({
158
+ "role": "assistant",
159
+ "content": "Here's the chart:",
160
+ "plot": response_value
161
+ })
162
+ else:
163
+ st.session_state.chat_history.append({
164
+ "role": "assistant",
165
+ "content": str(response_value)
166
+ })
167
+ else:
168
+ st.session_state.chat_history.append({
169
+ "role": "assistant",
170
+ "content": str(result)
171
+ })
172
+ else:
173
+ st.write("No data loaded.")
174
+
175
+ except Exception as e:
176
+ st.error(f"Error processing input: {e}")
177
+
178
+ def main():
179
+ st.set_page_config(page_title="AI Chat with Your Data", page_icon="📊")
180
+
181
+ # Initialize session state variables
182
+ if "chat_history" not in st.session_state:
183
+ st.session_state.chat_history = []
184
+ if "dfs" not in st.session_state:
185
+ st.session_state.dfs = fetch_data() # Load DataFrame at startup
186
+
187
+ st.title("AI Chat with Your Data 📊")
188
+
189
+ # Display chat history
190
+ for message in st.session_state.chat_history:
191
+ with st.chat_message(message["role"]):
192
+ render_chat_message(message)
193
+
194
+ # Chat input
195
+ user_question = st.chat_input("Ask a question about your data:")
196
+
197
+ if user_question:
198
+ handle_userinput(user_question, st.session_state.dfs)
199
+
200
+ # Sidebar with options
201
+ with st.sidebar:
202
+ st.subheader("Options")
203
+
204
+ if st.button("Reload Data"):
205
+ with st.spinner("Fetching latest data..."):
206
+ st.session_state.dfs = fetch_data()
207
+ st.success("Data refreshed!")
208
+
209
+ if st.button("Clear Chat"):
210
+ st.session_state.chat_history = []
211
+ st.rerun()
212
+
213
+ if __name__ == "__main__":
214
+ main()
215
+
216
+