File size: 9,150 Bytes
7595864
 
 
 
a36f327
b855b87
 
59cee21
 
7595864
59cee21
b855b87
 
 
 
 
59cee21
b855b87
a36f327
59cee21
 
 
 
 
 
b855b87
59cee21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7595864
59cee21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7595864
 
59cee21
7595864
 
59cee21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b855b87
 
59cee21
b855b87
59cee21
 
 
b855b87
59cee21
 
 
b855b87
 
 
59cee21
 
 
 
b855b87
7595864
b855b87
 
 
59cee21
 
b855b87
59cee21
 
 
b855b87
 
59cee21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b855b87
 
 
7595864
 
 
 
b855b87
7595864
 
59cee21
7595864
59cee21
7595864
59cee21
b855b87
59cee21
b855b87
 
59cee21
b855b87
 
 
 
 
7595864
59cee21
 
 
7595864
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# app.py
import gradio as gr
import pandas as pd
import io
import os
import google.generativeai as genai
import gc
import traceback
from typing import Tuple, Optional

# Read the Gemini API key from the GEMINI_API_KEY environment variable
# (populated from the Space's Secrets) so the key never lives in the code.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Fail at import time: a missing or empty key stops the app before the UI is built.
    raise ValueError("Gemini API key not set. Please add GEMINI_API_KEY in Space Secrets.")
genai.configure(api_key=GEMINI_API_KEY)

# Most recently loaded DataFrame. This is a single module-level global:
# load_file() sets it, ask_question_gemini() reads it, and fn_clear() resets it to None.
session_df = None

# ---------------- robust file-reading helper ----------------
def _upload_name(file) -> Optional[str]:
    """Return the upload's ``name`` attribute, falling back to ``filename``."""
    return getattr(file, "name", None) or getattr(file, "filename", None)


def read_file_bytes_flexible(file) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
    """
    Extract raw bytes and a filename from an uploaded object of unknown shape.

    Upload widgets hand over different objects depending on framework and
    version, so the probes below are tried in order and the first one that
    yields bytes wins:

    1. raw ``bytes`` / ``bytearray``
    2. a ``str`` / ``os.PathLike`` path to an existing file
    3. a ``.bytes`` attribute holding bytes
    4. a synchronous ``.read()`` returning bytes or str
    5. a nested ``.file`` object with ``.read()``
    6. a dict carrying bytes or a path under a common key
    7. a ``name`` / ``filename`` / ``path`` attribute naming an existing file

    Args:
        file: The uploaded object, or None when nothing was uploaded.

    Returns:
        ``(content, filename, error)``. On success ``content`` is ``bytes``,
        ``filename`` may be None when it cannot be determined, and ``error``
        is None. On failure ``content`` and ``filename`` are None and
        ``error`` is a human-readable message.
    """
    if file is None:
        return None, None, "No file uploaded."

    # 1) Already raw bytes.
    if isinstance(file, (bytes, bytearray)):
        return bytes(file), None, None

    # 2) A filesystem path passed directly. If the path does not point at a
    #    regular file we fall through to the generic "not supported" error.
    if isinstance(file, (str, os.PathLike)):
        path = os.fsdecode(file)
        if os.path.isfile(path):
            try:
                with open(path, "rb") as f:
                    return f.read(), os.path.basename(path), None
            except OSError as e:
                return None, None, f"Could not read uploaded file '{path}': {e}"

    # 3) Object exposes the payload as a 'bytes' attribute (some wrappers do).
    #    Broad except is deliberate: attribute access on an arbitrary object
    #    may raise anything, and a failed probe must not abort the others.
    try:
        b = getattr(file, "bytes", None)
        if isinstance(b, (bytes, bytearray)):
            return bytes(b), _upload_name(file), None
    except Exception:
        pass

    # 4) Object has a callable read().
    read_attr = getattr(file, "read", None)
    if callable(read_attr):
        try:
            content = read_attr()
            if hasattr(content, "__await__"):
                # Async read(): it cannot be awaited from sync code. Close the
                # awaitable so it is not left dangling, then try file.file below.
                close = getattr(content, "close", None)
                if callable(close):
                    close()
            elif isinstance(content, (bytes, bytearray)):
                return bytes(content), _upload_name(file), None
            elif isinstance(content, str):
                # Text-mode handle: re-encode so callers always receive bytes.
                return content.encode("utf-8"), _upload_name(file), None
        except Exception:
            # read() may need arguments or fail in this context; keep probing.
            pass

    # 5) Object wraps a file-like in '.file' (e.g. starlette UploadFile.file).
    try:
        attr_file = getattr(file, "file", None)
        if attr_file is not None and hasattr(attr_file, "read"):
            content = attr_file.read()
            if isinstance(content, (bytes, bytearray)):
                return bytes(content), _upload_name(file), None
    except Exception:
        pass

    # 6) Dict payload: look for bytes, or a path string, under common keys.
    try:
        if isinstance(file, dict):
            for k in ("content", "data", "bytes", "file", "body"):
                v = file.get(k)
                if isinstance(v, (bytes, bytearray)):
                    name = file.get("name") or file.get("filename")
                    return bytes(v), name, None
                if isinstance(v, str) and os.path.exists(v):
                    with open(v, "rb") as f:
                        return f.read(), os.path.basename(v), None
    except Exception:
        pass

    # 7) Fallback: an attribute that holds a path to an existing file.
    try:
        for attr in ("name", "filename", "path"):
            val = getattr(file, attr, None)
            if isinstance(val, str) and os.path.exists(val):
                with open(val, "rb") as f:
                    return f.read(), os.path.basename(val), None
    except Exception:
        pass

    # 8) Give up with a helpful error (include repr for debugging).
    try:
        rep = repr(file)
    except Exception:
        rep = "<unrepresentable object>"
    return None, None, f"Uploaded file format not supported by this server environment. Object repr: {rep}"

# ---------------- load file to DataFrame ----------------
# Leading bytes of the binary containers handed to pandas.read_excel:
# ZIP archives (e.g. .xlsx) and OLE2 compound documents (legacy .xls).
_EXCEL_SIGNATURES = (b"PK\x03\x04", b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1")


def load_file(file) -> Tuple[Optional[pd.DataFrame], str]:
    """
    Parse an uploaded CSV or Excel file into a DataFrame and store it in ``session_df``.

    The format is chosen as follows: content that starts with a ZIP or OLE2
    signature is always read as Excel; otherwise a ``.csv`` filename or a
    comma within the first 200 bytes selects CSV; anything else is read as
    Excel. Checking the signature first matters because the compressed header
    of a workbook can contain a comma byte by chance, which would otherwise
    send it to the CSV parser.

    Args:
        file: The uploaded object, in any shape ``read_file_bytes_flexible`` accepts.

    Returns:
        ``(df, status_message)``. ``df`` is None when reading or parsing
        failed, in which case ``session_df`` is left unchanged and the
        message describes the problem.
    """
    global session_df
    content, fname, err = read_file_bytes_flexible(file)
    if err:
        return None, f"Error reading file: {err}"
    if content is None:
        return None, "No bytes could be read from uploaded object."
    if not content:
        # Nothing to parse; say so plainly instead of surfacing a parser traceback.
        return None, "Uploaded file is empty."

    # The filename may be missing or, for unusual upload objects, not a string.
    name = fname.lower() if isinstance(fname, str) else ""
    looks_like_excel = bytes(content[:8]).startswith(_EXCEL_SIGNATURES)
    is_csv = not looks_like_excel and (name.endswith(".csv") or b"," in content[:200])

    try:
        if is_csv:
            df = pd.read_csv(io.BytesIO(content))
        else:
            # assume excel by default
            df = pd.read_excel(io.BytesIO(content))
    except Exception as e:
        # include traceback to help debug unusual formats (will show in UI only)
        tb = traceback.format_exc()
        return None, f"Error parsing file into DataFrame: {e}\n{tb}"
    finally:
        # Drop the raw upload bytes as soon as parsing is over, success or not.
        del content
        gc.collect()

    session_df = df
    return df, f"File loaded: {df.shape[0]} rows x {df.shape[1]} columns."

# ---------------- Gemini-powered question answering ----------------
def _extract_code(text: str) -> str:
    """Return the Python source from a model reply, without Markdown code fences."""
    fence = "```"
    start = text.find(fence)
    if start == -1:
        # No fenced block: treat the whole reply as code.
        return text.strip("`\n ")
    body = text[start + len(fence):]
    end = body.find(fence)
    if end != -1:
        body = body[:end]
    # Drop the remainder of the opening fence line when it is empty or just a
    # language tag such as "python"; otherwise the tag would be executed as code.
    first_line, newline, rest = body.partition("\n")
    tag = first_line.strip()
    if newline and (not tag or tag.isalnum()):
        body = rest
    return body.strip("`\n ")


def ask_question_gemini(query: str):
    """
    Answer a natural-language question about the loaded dataset using Gemini.

    Sends the question, the column names and the first 10 rows (as CSV) to
    Gemini, expects Python code back that assigns a DataFrame to ``result``,
    and executes that code against a copy of ``session_df``.

    SECURITY: the generated code is run with ``exec`` and is NOT sandboxed. It
    has full interpreter privileges (builtins, imports, filesystem, environment
    variables). Both the question and the uploaded data influence what the
    model writes, so only expose this to trusted users or run the app in an
    isolated environment.

    Args:
        query: The user's question.

    Returns:
        ``(result_df, status_message)``. ``result_df`` is at most 200 rows, or
        None when no file is loaded, the question is blank, the Gemini call
        fails, the generated code raises, or it does not produce a DataFrame
        or a scalar.
    """
    global session_df
    if session_df is None:
        return None, "Please upload and load a file first."
    if not query or not query.strip():
        return None, "Please enter a question."

    # build prompt: include columns & small preview
    cols = list(session_df.columns)
    preview_csv = session_df.head(10).to_csv(index=False)
    # The literal braces in the last prompt line are doubled: inside an f-string a
    # single pair would be evaluated as a replacement field and raise ValueError.
    prompt = f"""
You are a helpful Python data analyst. The user uploaded a dataset with columns: {cols}.
Here are the first 10 rows (CSV):
{preview_csv}

User question: {query}

Return ONLY Python code (no explanations) that when executed will create a pandas DataFrame named `result`
that contains the answer (a DataFrame, up to 200 rows). Use `df` as the variable for the dataset.
Do not import libraries; assume pandas is available as pd. If you need to compute percentages, include them as columns.
If the query asks for a single number, return it as a one-row DataFrame, e.g. pd.DataFrame({{'value':[...]}}).
"""
    try:
        model = genai.GenerativeModel("gemini-pro")
        response = model.generate_content(prompt)
        code = _extract_code(response.text)
    except Exception as e:
        return None, f"Error calling Gemini: {e}"

    # One dict serves as both globals and locals so that lambdas, functions and
    # comprehensions inside the generated code can still see `pd` and `df`.
    # The code works on a copy, so it cannot mutate the cached DataFrame.
    namespace = {"pd": pd, "df": session_df.copy(), "result": None}
    try:
        # NOTE(security): executes model-generated code without a sandbox; see docstring.
        exec(code, namespace)
    except Exception as e:
        tb = traceback.format_exc()
        return None, f"Error executing code returned by Gemini: {e}\nCode was:\n{code}\n\nTraceback:\n{tb}"

    result = namespace.get("result")
    if isinstance(result, pd.DataFrame):
        # limit to 200 rows to avoid huge outputs
        return result.head(200), "Success — executed Gemini code."
    # Wrap scalars, including numpy scalar types that are not int/float/str subclasses.
    if result is not None and pd.api.types.is_scalar(result):
        return pd.DataFrame({"value": [result]}), "Gemini returned a scalar; wrapped into DataFrame."
    return None, f"Gemini did not return a DataFrame. Code was:\n{code}"

# ---------------- Gradio functions ----------------
def fn_load(file):
    """
    Gradio callback for the "Load file" button.

    Delegates to ``load_file`` and returns ``(preview, status_message)``,
    where ``preview`` is the first five rows of the parsed DataFrame, or
    None when loading failed.
    """
    frame, status_message = load_file(file)
    preview = frame.head(5) if frame is not None else None
    return preview, status_message

def fn_ask(query):
    """
    Gradio callback for the "Ask Gemini" button.

    Forwards the question to ``ask_question_gemini`` and returns its
    ``(result_table, status_message)`` pair unchanged.
    """
    answer_table, status_message = ask_question_gemini(query)
    return answer_table, status_message

def fn_clear():
    """
    Gradio callback for the "Clear / Reset" button.

    Drops the cached DataFrame and returns one update per component wired to
    the button, in the order the outputs are listed there:
    file_input, preview_table, query_input, result_table.

    The last output is a Dataframe, so it is cleared with ``None`` like the
    preview table; only the query Textbox takes an empty string. The
    component-agnostic ``gr.update`` is used instead of the per-class
    ``.update()`` helpers, which are not available in newer Gradio releases.

    Returns:
        A 4-tuple of Gradio update objects.
    """
    global session_df
    session_df = None
    gc.collect()
    return (
        gr.update(value=None),  # file_input: remove the uploaded file
        gr.update(value=None),  # preview_table: empty the preview
        gr.update(value=""),    # query_input: blank the question box
        gr.update(value=None),  # result_table: empty the result
    )

# ---------------- UI ----------------
# Build the Gradio interface. Components appear in the page in the order they
# are created inside the Blocks context, so statement order here is the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Chat-with-CSV — Gemini-powered (secure API key via Secrets)")
    # Upload widget and its "Load file" button share one row.
    with gr.Row():
        file_input = gr.File(label="Upload CSV or XLSX (will not be saved)")
        load_btn = gr.Button("Load file")
    # Filled by fn_load: first rows of the parsed file plus a status line.
    preview_table = gr.Dataframe(headers=None, label="Preview (first 5 rows)")
    file_status = gr.Textbox(label="File status")

    # Question box and the outputs filled by fn_ask.
    query_input = gr.Textbox(label="Ask a question (English)")
    ask_btn = gr.Button("Ask Gemini")
    result_table = gr.Dataframe(headers=None, label="Result")
    status = gr.Textbox(label="Status / Messages")

    clear_btn = gr.Button("Clear / Reset")

    # Event wiring. Each callback's return values are applied positionally to
    # the components in `outputs`, so the callback must return exactly one
    # value per listed component, in the same order.
    load_btn.click(fn=fn_load, inputs=file_input, outputs=[preview_table, file_status])
    ask_btn.click(fn=fn_ask, inputs=query_input, outputs=[result_table, status])
    # The two status textboxes are not listed here, so Clear leaves them as they are.
    clear_btn.click(fn=fn_clear, outputs=[file_input, preview_table, query_input, result_table])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()