rithwikreal committed on
Commit
b855b87
·
verified ·
1 Parent(s): 9d86fb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -277
app.py CHANGED
@@ -2,308 +2,108 @@
2
  import gradio as gr
3
  import pandas as pd
4
  import io
5
- import re
6
- import gc
7
  import os
8
- from typing import Tuple, Optional, List
 
9
 
10
- # ---------- Helpers for uploaded file reading ----------
11
- def read_uploaded_file(file):
12
- """Try multiple ways to get bytes from Gradio upload objects."""
13
- if file is None:
14
- return None, None, "No file uploaded."
15
- try:
16
- if hasattr(file, "read"):
17
- content = file.read()
18
- name = getattr(file, "name", None)
19
- return content, name, None
20
- except Exception:
21
- pass
22
- try:
23
- if isinstance(file, (str, os.PathLike)):
24
- path = str(file)
25
- if os.path.exists(path):
26
- with open(path, "rb") as f:
27
- content = f.read()
28
- return content, os.path.basename(path), None
29
- except Exception:
30
- pass
31
- try:
32
- if isinstance(file, dict):
33
- name = file.get("name") or file.get("filename")
34
- data = file.get("data") or file.get("content") or file.get("bytes")
35
- if isinstance(data, (bytes, bytearray)):
36
- return data, name, None
37
- if isinstance(data, str) and os.path.exists(data):
38
- with open(data, "rb") as f:
39
- content = f.read()
40
- return content, name or os.path.basename(data), None
41
- except Exception:
42
- pass
43
- try:
44
- name = getattr(file, "name", None)
45
- if name and isinstance(name, str) and os.path.exists(name):
46
- with open(name, "rb") as f:
47
- content = f.read()
48
- return content, os.path.basename(name), None
49
- except Exception:
50
- pass
51
- return None, None, "Uploaded file format not supported by this server environment."
52
 
53
- def load_file_bytes_to_df(file) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
54
- """Read bytes and convert to DataFrame (no disk writes)."""
55
- content, name, err = read_uploaded_file(file)
56
- if err:
57
- return None, f"Error reading file: {err}"
58
- if content is None:
59
- return None, "No content read from uploaded file."
60
  try:
61
- fname = (name or "").lower()
62
- if fname.endswith(".csv") or (isinstance(content, (bytes, bytearray)) and b"," in content[:200]):
 
63
  df = pd.read_csv(io.BytesIO(content))
64
  else:
65
  df = pd.read_excel(io.BytesIO(content))
 
 
66
  except Exception as e:
67
  return None, f"Error reading file: {e}"
68
- finally:
69
- try:
70
- del content
71
- except Exception:
72
- pass
73
- gc.collect()
74
- return df, None
75
-
76
- # ---------- Column matching in queries ----------
77
- def find_columns_in_query(columns: List[str], query: str, max_matches: int = 3) -> List[str]:
78
- """Return a list of best matching column names from the DataFrame for words in the query."""
79
- q = query.lower()
80
- found = []
81
- # exact word matches first
82
- for col in columns:
83
- cl = col.lower()
84
- # exact full word present
85
- if re.search(r"\b" + re.escape(cl) + r"\b", q):
86
- found.append(col)
87
- if len(found) >= max_matches:
88
- return found
89
- # partial matches (any token)
90
- q_tokens = set(re.findall(r"[a-z0-9_]+", q))
91
- for col in columns:
92
- if col in found:
93
- continue
94
- cl = col.lower()
95
- col_tokens = set(re.findall(r"[a-z0-9_]+", cl))
96
- if q_tokens & col_tokens:
97
- found.append(col)
98
- if len(found) >= max_matches:
99
- return found
100
- # fallback: if query contains "department" but no exact column, look for column names containing department
101
- for col in columns:
102
- if "department" in col.lower() and col not in found:
103
- found.append(col)
104
- if len(found) >= max_matches:
105
- return found
106
- return found
107
 
108
- # ---------- Aggregation helpers ----------
109
- def group_count(df: pd.DataFrame, group_col: str, top_n: Optional[int] = None):
110
- res = df.groupby(group_col).size().reset_index(name="count").sort_values("count", ascending=False).reset_index(drop=True)
111
- if top_n:
112
- return res.head(top_n)
113
- return res
114
-
115
- def group_agg(df: pd.DataFrame, group_col: str, value_col: str, agg: str):
116
- if agg in ("mean", "avg", "average"):
117
- res = df.groupby(group_col)[value_col].mean().reset_index().rename(columns={value_col: "average"})
118
- elif agg in ("sum",):
119
- res = df.groupby(group_col)[value_col].sum().reset_index().rename(columns={value_col: "sum"})
120
- elif agg in ("max",):
121
- res = df.groupby(group_col)[value_col].max().reset_index().rename(columns={value_col: "max"})
122
- elif agg in ("min",):
123
- res = df.groupby(group_col)[value_col].min().reset_index().rename(columns={value_col: "min"})
124
- else:
125
- res = df.groupby(group_col)[value_col].agg(agg).reset_index().rename(columns={value_col: agg})
126
- return res.sort_values(res.columns[-1], ascending=False).reset_index(drop=True)
127
-
128
- def compute_percentage_counts(df: pd.DataFrame, group_col: str):
129
- counts = group_count(df, group_col)
130
- total = counts["count"].sum()
131
- counts["percentage"] = (counts["count"] / total * 100).round(2)
132
- return counts
133
-
134
- def compute_percentage_of_value(df: pd.DataFrame, group_col: str, value_col: str):
135
- # percent share of value_col per group
136
- sums = df.groupby(group_col)[value_col].sum().reset_index().rename(columns={value_col: "sum"})
137
- total = sums["sum"].sum()
138
- sums["percentage"] = (sums["sum"] / total * 100).round(2)
139
- return sums.sort_values("sum", ascending=False).reset_index(drop=True)
140
-
141
- # ---------- Natural language parser & action ----------
142
- def simple_nl_to_action(df: pd.DataFrame, query: str):
143
- q = (query or "").strip().lower()
144
- if q == "":
145
- return None, "Please type a question like: 'department wise head count', 'percentage of employees by department', 'average salary by department', or 'show columns'."
146
-
147
- cols = list(df.columns)
148
- matched = find_columns_in_query(cols, q, max_matches=3) # up to 3 column matches
149
-
150
- # direct commands
151
- if "columns" in q or "show columns" in q or "list columns" in q:
152
- return pd.DataFrame({"columns": cols}), None
153
-
154
- # overall totals
155
- if re.search(r"\b(total|how many|count of rows|number of rows|total employees|total employee)\b", q):
156
- return pd.DataFrame({"total_rows": [len(df)]}), None
157
-
158
- # show first N rows
159
- m = re.search(r"(first|head)\s*(\d+)?", q)
160
- if "head" in q or "first" in q:
161
- n = 5
162
- if m and m.group(2):
163
- n = int(m.group(2))
164
- return df.head(n), None
165
-
166
- # describe / summary
167
- if "describe" in q or "summary" in q or "statistics" in q:
168
- return df.describe(include='all').reset_index(), None
169
-
170
- # HEADCOUNT / COUNT requests (department wise head count etc.)
171
- if any(w in q for w in ["headcount", "head count", "head-count", "headcounts", "head count", "number of employees", "how many", "count by", "count of", "count"]):
172
- # If a grouping column is mentioned, use it
173
- if matched:
174
- group_col = matched[0]
175
- # if user mentions percentage as well
176
- if "%" in q or "percentage" in q or "percent" in q or "share" in q:
177
- return compute_percentage_counts(df, group_col), None
178
- # If they asked which has maximum
179
- if any(w in q for w in ["most", "maximum", "max", "highest", "where max", "to which"]):
180
- counts = group_count(df, group_col)
181
- top = counts.head(1)
182
- # also show full counts for context
183
- summary = counts
184
- # build a small output that includes top and summary (we'll return summary; top is first row)
185
- return summary, f"Top: {top.iloc[0,0]} with {top.iloc[0,1]} (rows)."
186
- # just return counts
187
- return group_count(df, group_col), None
188
- else:
189
- # no group column mentioned: return total rows
190
- return pd.DataFrame({"total_rows": [len(df)]}), None
191
-
192
- # AGGREGATION requests (average, mean, sum, max/min of a numeric column grouped by another)
193
- if any(w in q for w in ["average", "mean", "avg", "sum", "total", "maximum", "minimum", "max", "min"]):
194
- # try to detect grouping and value column
195
- if len(matched) >= 2:
196
- group_col = matched[0]
197
- value_col = matched[1]
198
- elif len(matched) == 1:
199
- # ambiguous: user mentioned one column. If that's numeric, perhaps they want overall average
200
- cand = matched[0]
201
- if pd.api.types.is_numeric_dtype(df[cand]):
202
- # overall stat
203
- if any(w in q for w in ["average", "mean", "avg"]):
204
- return pd.DataFrame({f"overall_{cand}_average": [df[cand].mean()]}), None
205
- if "sum" in q or "total" in q:
206
- return pd.DataFrame({f"overall_{cand}_sum": [df[cand].sum()]}), None
207
- # else ask for more clarity
208
- return None, "I found one column but couldn't tell grouping vs value column. Please ask like 'average Salary by Department' or 'sum Sales by Region'."
209
- else:
210
- return None, "Please mention columns. Example: 'average Salary by Department' or 'sum Sales by Region'."
211
- # determine aggregation type
212
- if any(w in q for w in ["average", "mean", "avg"]):
213
- return group_agg(df, group_col, value_col, "mean"), None
214
- if any(w in q for w in ["sum", "total"]):
215
- return group_agg(df, group_col, value_col, "sum"), None
216
- if any(w in q for w in ["max", "maximum", "highest"]):
217
- return group_agg(df, group_col, value_col, "max"), None
218
- if any(w in q for w in ["min", "minimum", "lowest"]):
219
- return group_agg(df, group_col, value_col, "min"), None
220
-
221
- # PERCENTAGE requests for a numeric column per group
222
- if any(w in q for w in ["percentage", "%", "percent", "share"]):
223
- # if two columns mentioned, assume first is group, second is numeric value
224
- if len(matched) >= 2:
225
- group_col = matched[0]
226
- value_col = matched[1]
227
- if pd.api.types.is_numeric_dtype(df[value_col]):
228
- return compute_percentage_of_value(df, group_col, value_col), None
229
- else:
230
- return None, f"Column '{value_col}' is not numeric; cannot compute percentage of values."
231
- elif len(matched) == 1:
232
- group_col = matched[0]
233
- # percent of counts
234
- return compute_percentage_counts(df, group_col), None
235
- else:
236
- return None, "Please mention the group column (and optionally a numeric column). Example: 'percentage of Salary by Department' or 'percentage of employees by Department'."
237
-
238
- # SHOW specific columns (e.g., 'show Department and Salary')
239
- m = re.search(r"show (.+)", q)
240
- if m:
241
- # try to extract column names from matched list
242
- if matched:
243
- # if user asked show with two columns, return them
244
- return df[matched].head(200), None
245
- else:
246
- return None, "Couldn't identify columns to show. Use 'show columns' to view exact names."
247
-
248
- # fallback: return first 10 rows with suggestion
249
- return df.head(10), "Couldn't parse exact request — showing first 10 rows. Try: 'show columns', 'department wise head count', 'percentage of employees by department', or 'average Salary by Department'."
250
-
251
- # ---------- Processing wrapper ----------
252
- def process(file, query):
253
- df, err = load_file_bytes_to_df(file)
254
- if err:
255
- try:
256
- del file
257
- except Exception:
258
- pass
259
- gc.collect()
260
- return None, err
261
 
262
  try:
263
- res, msg = simple_nl_to_action(df, query)
264
- if isinstance(res, pd.DataFrame):
265
- out_df = res.copy()
 
 
 
 
 
 
 
 
 
266
  else:
267
- out_df = None
268
- except Exception as e:
269
- out_df = None
270
- msg = f"Error while processing: {e}"
271
 
272
- try:
273
- del df
274
- del file
275
- except Exception:
276
- pass
277
- gc.collect()
278
-
279
- if isinstance(out_df, pd.DataFrame):
280
- return out_df, (msg or "OK")
281
- else:
282
- return None, (msg or "No result")
283
 
284
- # ---------- Clear / reset ----------
285
  def clear_all():
 
 
 
286
  return (
287
  gr.File.update(value=None),
288
- gr.Textbox.update(value=""),
289
  gr.Dataframe.update(value=None),
290
  gr.Textbox.update(value=""),
 
291
  )
292
 
293
- # ---------- Gradio UI ----------
294
  with gr.Blocks() as demo:
295
- gr.Markdown("# Chat-with-CSV enhanced analysis (ephemeral uploads)")
296
- with gr.Row():
297
- file_input = gr.File(label="Upload CSV or XLSX (will not be saved)", file_count="single")
298
- query_input = gr.Textbox(label="Ask a question (examples: 'department wise head count', 'percentage of Salary by Department', 'average Salary by Department')", placeholder="Type your question here")
299
  with gr.Row():
300
- submit = gr.Button("Run query")
301
- clear_btn = gr.Button("Clear / Reset (remove uploaded file & results)")
302
- output_table = gr.Dataframe(headers=None, label="Result table")
303
- status = gr.Textbox(label="Status / Messages", interactive=False)
 
 
 
 
 
 
 
304
 
305
- submit.click(fn=process, inputs=[file_input, query_input], outputs=[output_table, status])
306
- clear_btn.click(fn=clear_all, inputs=None, outputs=[file_input, query_input, output_table, status])
 
307
 
308
  if __name__ == "__main__":
309
  demo.launch()
 
2
  import gradio as gr
3
  import pandas as pd
4
  import io
 
 
5
  import os
6
+ import google.generativeai as genai
7
+ import gc
8
 
9
+ # Load API key securely from Hugging Face secret
10
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
11
+ if not GEMINI_API_KEY:
12
+ raise ValueError("Gemini API key not set. Please add GEMINI_API_KEY in Space Secrets.")
13
+
14
+ genai.configure(api_key=GEMINI_API_KEY)
15
+
16
+ # Keep DataFrame in memory during session
17
+ session_df = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ def load_file(file):
20
+ """Load uploaded CSV/XLSX into pandas DataFrame."""
21
+ global session_df
22
+ if file is None:
23
+ return None, "No file uploaded"
 
 
24
  try:
25
+ name = getattr(file, "name", "")
26
+ content = file.read()
27
+ if name.endswith(".csv") or b"," in content[:200]:
28
  df = pd.read_csv(io.BytesIO(content))
29
  else:
30
  df = pd.read_excel(io.BytesIO(content))
31
+ session_df = df
32
+ return df.head(5), f"File loaded with {df.shape[0]} rows and {df.shape[1]} columns."
33
  except Exception as e:
34
  return None, f"Error reading file: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def ask_question(query):
37
+ """Send the question + DF structure to Gemini and run returned Python code."""
38
+ global session_df
39
+ if session_df is None:
40
+ return None, "Please upload a file first."
41
+
42
+ # Build prompt for Gemini
43
+ preview = session_df.head(10).to_csv(index=False)
44
+ columns = list(session_df.columns)
45
+ prompt = f"""
46
+ You are a data analyst.
47
+ The user uploaded a dataset with these columns: {columns}.
48
+ Here are the first 10 rows:
49
+ {preview}
50
+
51
+ User question: {query}
52
+
53
+ Write Python pandas code (only the code, no explanations, no imports) that answers the question
54
+ and assigns the result to a variable named result.
55
+ If aggregation is needed, show a DataFrame (not just a number).
56
+ Keep the output concise (max 200 rows).
57
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  try:
60
+ # Ask Gemini to generate code
61
+ model = genai.GenerativeModel("gemini-pro")
62
+ response = model.generate_content(prompt)
63
+ code = response.text.strip("`\n ")
64
+
65
+ # Execute the code safely
66
+ local_vars = {"pd": pd, "result": None, "df": session_df.copy()}
67
+ exec(code, {}, local_vars)
68
+ result = local_vars.get("result")
69
+
70
+ if isinstance(result, pd.DataFrame):
71
+ return result, f"Answer based on your question: {query}"
72
  else:
73
+ return None, f"No table returned. Code was:\n{code}"
 
 
 
74
 
75
+ except Exception as e:
76
+ return None, f"Error: {e}"
 
 
 
 
 
 
 
 
 
77
 
 
78
  def clear_all():
79
+ global session_df
80
+ session_df = None
81
+ gc.collect()
82
  return (
83
  gr.File.update(value=None),
 
84
  gr.Dataframe.update(value=None),
85
  gr.Textbox.update(value=""),
86
+ gr.Textbox.update(value=""),
87
  )
88
 
 
89
  with gr.Blocks() as demo:
90
+ gr.Markdown("# Chat with CSV (Gemini-powered, private API key)")
 
 
 
91
  with gr.Row():
92
+ file_input = gr.File(label="Upload CSV/XLSX")
93
+ load_btn = gr.Button("Load file")
94
+ file_preview = gr.Dataframe(headers=None, label="Preview (first 5 rows)")
95
+ file_status = gr.Textbox(label="File status")
96
+
97
+ query_input = gr.Textbox(label="Ask a question")
98
+ ask_btn = gr.Button("Ask Gemini")
99
+ result_table = gr.Dataframe(headers=None, label="Result")
100
+ status = gr.Textbox(label="Status / Messages")
101
+
102
+ clear_btn = gr.Button("Clear / Reset")
103
 
104
+ load_btn.click(fn=load_file, inputs=file_input, outputs=[file_preview, file_status])
105
+ ask_btn.click(fn=ask_question, inputs=query_input, outputs=[result_table, status])
106
+ clear_btn.click(fn=clear_all, outputs=[file_input, file_preview, query_input, result_table])
107
 
108
  if __name__ == "__main__":
109
  demo.launch()