rithwikreal commited on
Commit
59cee21
·
verified ·
1 Parent(s): b855b87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -49
app.py CHANGED
@@ -5,77 +5,207 @@ import io
5
  import os
6
  import google.generativeai as genai
7
  import gc
 
 
8
 
9
- # Load API key securely from Hugging Face secret
10
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
11
  if not GEMINI_API_KEY:
12
  raise ValueError("Gemini API key not set. Please add GEMINI_API_KEY in Space Secrets.")
13
-
14
  genai.configure(api_key=GEMINI_API_KEY)
15
 
16
- # Keep DataFrame in memory during session
17
  session_df = None
18
 
19
- def load_file(file):
20
- """Load uploaded CSV/XLSX into pandas DataFrame."""
21
- global session_df
 
 
 
22
  if file is None:
23
- return None, "No file uploaded"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
- name = getattr(file, "name", "")
26
- content = file.read()
27
- if name.endswith(".csv") or b"," in content[:200]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  df = pd.read_csv(io.BytesIO(content))
29
  else:
 
30
  df = pd.read_excel(io.BytesIO(content))
31
- session_df = df
32
- return df.head(5), f"File loaded with {df.shape[0]} rows and {df.shape[1]} columns."
33
  except Exception as e:
34
- return None, f"Error reading file: {e}"
35
-
36
- def ask_question(query):
37
- """Send the question + DF structure to Gemini and run returned Python code."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  global session_df
39
  if session_df is None:
40
- return None, "Please upload a file first."
41
 
42
- # Build prompt for Gemini
43
- preview = session_df.head(10).to_csv(index=False)
44
- columns = list(session_df.columns)
45
  prompt = f"""
46
- You are a data analyst.
47
- The user uploaded a dataset with these columns: {columns}.
48
- Here are the first 10 rows:
49
- {preview}
50
 
51
  User question: {query}
52
 
53
- Write Python pandas code (only the code, no explanations, no imports) that answers the question
54
- and assigns the result to a variable named result.
55
- If aggregation is needed, show a DataFrame (not just a number).
56
- Keep the output concise (max 200 rows).
57
  """
58
-
59
  try:
60
- # Ask Gemini to generate code
61
  model = genai.GenerativeModel("gemini-pro")
62
  response = model.generate_content(prompt)
63
  code = response.text.strip("`\n ")
 
 
64
 
65
- # Execute the code safely
66
- local_vars = {"pd": pd, "result": None, "df": session_df.copy()}
 
67
  exec(code, {}, local_vars)
68
- result = local_vars.get("result")
69
-
70
- if isinstance(result, pd.DataFrame):
71
- return result, f"Answer based on your question: {query}"
72
- else:
73
- return None, f"No table returned. Code was:\n{code}"
74
-
75
  except Exception as e:
76
- return None, f"Error: {e}"
77
-
78
- def clear_all():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  global session_df
80
  session_df = None
81
  gc.collect()
@@ -86,24 +216,25 @@ def clear_all():
86
  gr.Textbox.update(value=""),
87
  )
88
 
 
89
  with gr.Blocks() as demo:
90
- gr.Markdown("# Chat with CSV (Gemini-powered, private API key)")
91
  with gr.Row():
92
- file_input = gr.File(label="Upload CSV/XLSX")
93
  load_btn = gr.Button("Load file")
94
- file_preview = gr.Dataframe(headers=None, label="Preview (first 5 rows)")
95
  file_status = gr.Textbox(label="File status")
96
 
97
- query_input = gr.Textbox(label="Ask a question")
98
  ask_btn = gr.Button("Ask Gemini")
99
  result_table = gr.Dataframe(headers=None, label="Result")
100
  status = gr.Textbox(label="Status / Messages")
101
 
102
  clear_btn = gr.Button("Clear / Reset")
103
 
104
- load_btn.click(fn=load_file, inputs=file_input, outputs=[file_preview, file_status])
105
- ask_btn.click(fn=ask_question, inputs=query_input, outputs=[result_table, status])
106
- clear_btn.click(fn=clear_all, outputs=[file_input, file_preview, query_input, result_table])
107
 
108
  if __name__ == "__main__":
109
  demo.launch()
 
5
  import os
6
  import google.generativeai as genai
7
  import gc
8
+ import traceback
9
+ from typing import Tuple, Optional
10
 
11
+ # Load API key from secrets (don't put key in code)
12
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
13
  if not GEMINI_API_KEY:
14
  raise ValueError("Gemini API key not set. Please add GEMINI_API_KEY in Space Secrets.")
 
15
  genai.configure(api_key=GEMINI_API_KEY)
16
 
17
+ # session DataFrame (kept in memory for the session)
18
  session_df = None
19
 
20
+ # ---------------- robust file-reading helper ----------------
21
+ def read_file_bytes_flexible(file) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
22
+ """
23
+ Try many ways to extract raw bytes and filename from the uploaded object.
24
+ Returns: (content_bytes | None, filename | None, error_message | None)
25
+ """
26
  if file is None:
27
+ return None, None, "No file uploaded."
28
+
29
+ # 1) If it's already raw bytes
30
+ if isinstance(file, (bytes, bytearray)):
31
+ return bytes(file), None, None
32
+
33
+ # 2) If object has attribute 'bytes' (some wrappers do)
34
+ try:
35
+ b = getattr(file, "bytes", None)
36
+ if isinstance(b, (bytes, bytearray)):
37
+ # try name if available
38
+ name = getattr(file, "name", None) or getattr(file, "filename", None)
39
+ return bytes(b), name, None
40
+ except Exception:
41
+ pass
42
+
43
+ # 3) If object has attribute 'read' and calling it works
44
+ read_attr = getattr(file, "read", None)
45
+ if callable(read_attr):
46
+ try:
47
+ content = read_attr()
48
+ # some frameworks return coroutine for read() - handle it gracefully
49
+ if hasattr(content, "__await__"):
50
+ # can't await in sync; try file.file.read() below
51
+ pass
52
+ else:
53
+ if isinstance(content, (bytes, bytearray)):
54
+ name = getattr(file, "name", None) or getattr(file, "filename", None)
55
+ return bytes(content), name, None
56
+ # sometimes read() returns str (rare), turn to bytes
57
+ if isinstance(content, str):
58
+ return content.encode("utf-8"), getattr(file, "name", None), None
59
+ except TypeError:
60
+ # read() may require args or be not callable in this context
61
+ pass
62
+ except Exception:
63
+ # ignore and try other ways
64
+ pass
65
+
66
+ # 4) If object has a .file attribute (like starlette UploadFile.file)
67
  try:
68
+ attr_file = getattr(file, "file", None)
69
+ if attr_file is not None and hasattr(attr_file, "read"):
70
+ try:
71
+ content = attr_file.read()
72
+ if isinstance(content, (bytes, bytearray)):
73
+ name = getattr(file, "name", None) or getattr(file, "filename", None)
74
+ return bytes(content), name, None
75
+ except Exception:
76
+ pass
77
+ except Exception:
78
+ pass
79
+
80
+ # 5) If object is a dict-like (some environments)
81
+ try:
82
+ if isinstance(file, dict):
83
+ # common keys
84
+ for k in ("content", "data", "bytes", "file", "body"):
85
+ v = file.get(k)
86
+ if isinstance(v, (bytes, bytearray)):
87
+ name = file.get("name") or file.get("filename")
88
+ return bytes(v), name, None
89
+ if isinstance(v, str) and os.path.exists(v):
90
+ with open(v, "rb") as f:
91
+ return f.read(), os.path.basename(v), None
92
+ except Exception:
93
+ pass
94
+
95
+ # 6) Fallback: try attributes that might contain a path string
96
+ try:
97
+ for attr in ("name", "filename", "path"):
98
+ val = getattr(file, attr, None)
99
+ if isinstance(val, str) and os.path.exists(val):
100
+ with open(val, "rb") as f:
101
+ return f.read(), os.path.basename(val), None
102
+ except Exception:
103
+ pass
104
+
105
+ # 7) Give up with a helpful error (include repr for debugging)
106
+ try:
107
+ rep = repr(file)
108
+ except Exception:
109
+ rep = "<unrepresentable object>"
110
+ return None, None, f"Uploaded file format not supported by this server environment. Object repr: {rep}"
111
+
112
+ # ---------------- load file to DataFrame ----------------
113
+ def load_file(file) -> Tuple[Optional[pd.DataFrame], str]:
114
+ """
115
+ Returns (df or None, status_message).
116
+ """
117
+ global session_df
118
+ content, fname, err = read_file_bytes_flexible(file)
119
+ if err:
120
+ return None, f"Error reading file: {err}"
121
+ if content is None:
122
+ return None, "No bytes could be read from uploaded object."
123
+
124
+ try:
125
+ name = (fname or "").lower()
126
+ # Quick heuristic: csv if filename endswith .csv or bytes contain commas/newlines in header
127
+ if name.endswith(".csv") or (isinstance(content, (bytes, bytearray)) and b"," in content[:200]):
128
  df = pd.read_csv(io.BytesIO(content))
129
  else:
130
+ # assume excel by default
131
  df = pd.read_excel(io.BytesIO(content))
 
 
132
  except Exception as e:
133
+ # include traceback to help debug unusual formats (will show in UI only)
134
+ tb = traceback.format_exc()
135
+ return None, f"Error parsing file into DataFrame: {e}\n{tb}"
136
+ finally:
137
+ try:
138
+ del content
139
+ except Exception:
140
+ pass
141
+ gc.collect()
142
+
143
+ session_df = df
144
+ return df, f"File loaded: {df.shape[0]} rows x {df.shape[1]} columns."
145
+
146
+ # ---------------- Gemini-powered question answering ----------------
147
+ def ask_question_gemini(query: str):
148
+ """
149
+ Sends the user's query and a small preview to Gemini; expects back Python code that sets `result`.
150
+ Executes the code in a controlled local environment.
151
+ """
152
  global session_df
153
  if session_df is None:
154
+ return None, "Please upload and load a file first."
155
 
156
+ # build prompt: include columns & small preview
157
+ cols = list(session_df.columns)
158
+ preview_csv = session_df.head(10).to_csv(index=False)
159
  prompt = f"""
160
+ You are a helpful Python data analyst. The user uploaded a dataset with columns: {cols}.
161
+ Here are the first 10 rows (CSV):
162
+ {preview_csv}
 
163
 
164
  User question: {query}
165
 
166
+ Return ONLY Python code (no explanations) that when executed will create a pandas DataFrame named `result`
167
+ that contains the answer (a DataFrame, up to 200 rows). Use `df` as the variable for the dataset.
168
+ Do not import libraries; assume pandas is available as pd. If you need to compute percentages, include them as columns.
169
+ If the query asks for a single number, return it as a one-row DataFrame, e.g. pd.DataFrame({'value':[...]}).
170
  """
 
171
  try:
 
172
  model = genai.GenerativeModel("gemini-pro")
173
  response = model.generate_content(prompt)
174
  code = response.text.strip("`\n ")
175
+ except Exception as e:
176
+ return None, f"Error calling Gemini: {e}"
177
 
178
+ # Execute the code in a controlled namespace
179
+ local_vars = {"pd": pd, "df": session_df.copy(), "result": None}
180
+ try:
181
  exec(code, {}, local_vars)
 
 
 
 
 
 
 
182
  except Exception as e:
183
+ tb = traceback.format_exc()
184
+ return None, f"Error executing code returned by Gemini: {e}\nCode was:\n{code}\n\nTraceback:\n{tb}"
185
+
186
+ result = local_vars.get("result", None)
187
+ if isinstance(result, pd.DataFrame):
188
+ # limit to 200 rows to avoid huge outputs
189
+ return result.head(200), f"Success — executed Gemini code."
190
+ else:
191
+ # If not a DataFrame, try to wrap scalar into DF
192
+ if isinstance(result, (int, float, str)):
193
+ return pd.DataFrame({"value": [result]}), "Gemini returned a scalar; wrapped into DataFrame."
194
+ return None, f"Gemini did not return a DataFrame. Code was:\n{code}"
195
+
196
+ # ---------------- Gradio functions ----------------
197
+ def fn_load(file):
198
+ df, msg = load_file(file)
199
+ if df is None:
200
+ return None, msg
201
+ preview = df.head(5)
202
+ return preview, msg
203
+
204
+ def fn_ask(query):
205
+ res, msg = ask_question_gemini(query)
206
+ return res, msg
207
+
208
+ def fn_clear():
209
  global session_df
210
  session_df = None
211
  gc.collect()
 
216
  gr.Textbox.update(value=""),
217
  )
218
 
219
+ # ---------------- UI ----------------
220
  with gr.Blocks() as demo:
221
+ gr.Markdown("# Chat-with-CSV Gemini-powered (secure API key via Secrets)")
222
  with gr.Row():
223
+ file_input = gr.File(label="Upload CSV or XLSX (will not be saved)")
224
  load_btn = gr.Button("Load file")
225
+ preview_table = gr.Dataframe(headers=None, label="Preview (first 5 rows)")
226
  file_status = gr.Textbox(label="File status")
227
 
228
+ query_input = gr.Textbox(label="Ask a question (English)")
229
  ask_btn = gr.Button("Ask Gemini")
230
  result_table = gr.Dataframe(headers=None, label="Result")
231
  status = gr.Textbox(label="Status / Messages")
232
 
233
  clear_btn = gr.Button("Clear / Reset")
234
 
235
+ load_btn.click(fn=fn_load, inputs=file_input, outputs=[preview_table, file_status])
236
+ ask_btn.click(fn=fn_ask, inputs=query_input, outputs=[result_table, status])
237
+ clear_btn.click(fn=fn_clear, outputs=[file_input, preview_table, query_input, result_table])
238
 
239
  if __name__ == "__main__":
240
  demo.launch()