| from fastapi import FastAPI, Request, File, UploadFile, Form
|
| from fastapi.responses import HTMLResponse, JSONResponse
|
| from fastapi.staticfiles import StaticFiles
|
| from fastapi.templating import Jinja2Templates
|
| from io import BytesIO
|
| import base64
|
| import matplotlib.pyplot as plt
|
| import pandas as pd
|
| from google import genai
|
| from google.genai import types
|
| import os
|
|
|
|
|
# Gemini configuration. The API key comes from the GEMINI_API_KEY environment
# variable; "YOUR_API_KEY" is only a placeholder fallback, so with the variable
# unset the client is created with a non-working key and failures will
# presumably surface only when a request is made.
API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR_API_KEY")
MODEL = "gemini-2.5-flash-lite"

# A single module-level client shared by every request handler.
client = genai.Client(api_key=API_KEY)
|
|
|
|
|
app = FastAPI()

# Files under ./static are served at the /static URL prefix; the mount is
# registered under the route name "static".
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# HTML pages are rendered from Jinja2 templates found in ./templates.
templates = Jinja2Templates(directory="templates")
|
|
|
def get_metadata(df, *, sample_size=3):
    """Summarise *df* so it can be described to the LLM.

    Args:
        df: The uploaded spreadsheet as a pandas DataFrame.
        sample_size: Number of leading rows to include as examples
            (defaults to 3, the previous hard-coded value).

    Returns:
        A dict with the column labels, per-column dtype names, row and
        column counts, per-column null counts, per-column distinct-value
        counts (NaN excluded, as ``DataFrame.nunique`` does by default) and
        the first ``sample_size`` rows as a list of record dicts.
    """
    # Cast the numpy integer counts to plain ``int``: with numpy >= 2 their
    # repr is ``np.int64(3)``, which would otherwise leak into the f-string
    # prompt as noise, and numpy ints are not JSON-serializable.
    null_counts = {col: int(n) for col, n in df.isnull().sum().items()}
    unique_counts = {col: int(n) for col, n in df.nunique().items()}
    return {
        "columns": list(df.columns),
        "dtypes": df.dtypes.astype(str).to_dict(),
        "num_rows": df.shape[0],
        "num_cols": df.shape[1],
        "null_counts": null_counts,
        "unique_counts": unique_counts,
        "sample_rows": df.head(sample_size).to_dict(orient="records"),
    }
|
|
|
def generate_plot_code(user_query, metadata):
    """Ask Gemini for matplotlib/pandas code that draws the requested chart.

    Args:
        user_query: Free-text chart request from the user.
        metadata: Dataset summary as produced by ``get_metadata``; its
            ``columns``, ``dtypes``, ``null_counts``, ``unique_counts`` and
            ``sample_rows`` entries are embedded in the prompt.

    Returns:
        The model's reply with Markdown code fences removed and surrounding
        whitespace stripped. The prompt instructs the model to use an
        existing DataFrame named ``df``, which the caller must supply.
    """
    system_prompt, user_prompt = _build_plot_prompts(user_query, metadata)
    contents = [
        types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])
    ]
    config = types.GenerateContentConfig(
        temperature=0,
        max_output_tokens=1000,
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        system_instruction=[types.Part.from_text(text=system_prompt)],
    )
    stream = client.models.generate_content_stream(
        model=MODEL, contents=contents, config=config
    )
    # ``chunk.text`` is None for a chunk with no text parts (e.g. one that
    # carries only finish/usage metadata); treat that as empty text instead
    # of failing with ``str + None``.
    raw = "".join(chunk.text or "" for chunk in stream)
    return _strip_code_fences(raw)


def _build_plot_prompts(user_query, metadata):
    """Return the ``(system_prompt, user_prompt)`` pair sent to Gemini."""
    system_prompt = (
        "You are a Python plotting assistant.\n"
        "Use the existing DataFrame named df.\n"
        "Do NOT load any files.\n"
        "Use only matplotlib or pandas plotting.\n"
        f"Use only the following columns: {metadata['columns']}.\n"
        # The caller renders the current pyplot figure itself. show() may
        # block under an interactive backend, savefig() writes to the
        # server's disk, and close() leaves an empty figure to render.
        "Do NOT call plt.show(), plt.savefig() or plt.close().\n"
        "Do NOT explain, do NOT add extra text.\n"
        "Only produce executable code for plotting the requested chart.\n"
    )
    user_prompt = (
        "Dataset metadata:\n"
        f"Columns: {metadata['columns']}\n"
        f"Data types: {metadata['dtypes']}\n"
        f"Null counts: {metadata['null_counts']}\n"
        f"Unique counts: {metadata['unique_counts']}\n"
        f"Sample rows: {metadata['sample_rows']}\n"
        "\n"
        f"User request: {user_query}\n"
    )
    return system_prompt, user_prompt


def _strip_code_fences(text):
    """Return the code inside Markdown fences in *text*, or *text* de-fenced."""
    import re

    # Accept any info string (```python, ```py, ```Python, bare ```), ignore
    # prose around the fences, join multiple fenced blocks in order, and
    # tolerate a missing closing fence (e.g. output cut off at the token cap).
    blocks = re.findall(r"```[^\n]*\n(.*?)(?:```|\Z)", text, flags=re.DOTALL)
    if blocks:
        return "\n".join(block.strip("\n") for block in blocks).strip()
    # No well-formed fence: fall back to removing stray fence markers.
    return text.replace("```python", "").replace("```", "").strip()
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    # Jinja2Templates requires the incoming request in the template context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
@app.post("/generate_plot_file")
async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
    """Plot an uploaded Excel file according to a natural-language request.

    The spreadsheet is summarised with ``get_metadata``, Gemini is asked for
    plotting code via ``generate_plot_code``, and that code is executed with
    ``df`` bound to the DataFrame. The resulting pyplot figure is returned as
    a base64-encoded PNG.

    Args:
        file: Uploaded Excel workbook. ``pd.read_excel`` reads its first
            sheet by default.
        query: Free-text description of the desired chart.

    Returns:
        JSONResponse with ``success`` (bool), ``plot`` (base64 PNG string,
        empty on failure) and ``code`` (the generated code, or an empty
        string if none was produced, with a trailing ``# ERROR: ...`` comment
        when a step failed).
    """
    # NOTE(review): parsing, the Gemini call and exec() are all synchronous
    # and block the event loop while they run. pyplot's global state is not
    # thread-safe, so moving this to a threadpool would also need a lock.
    try:
        df = pd.read_excel(file.file)
    except Exception as e:
        # Unreadable or non-Excel upload: answer in the same JSON shape as the
        # normal path instead of failing with a bare HTTP 500.
        return JSONResponse(
            {"success": False, "plot": "", "code": f"# ERROR: could not read Excel file: {e}"}
        )

    metadata = get_metadata(df)

    try:
        code = generate_plot_code(query, metadata)
    except Exception as e:
        # Network/auth/quota errors from the model API.
        return JSONResponse(
            {"success": False, "plot": "", "code": f"# ERROR: code generation failed: {e}"}
        )

    # SECURITY: ``code`` is model output steered by the user-supplied query and
    # file contents, and exec() runs it with full interpreter privileges
    # (filesystem, network, os). This is effectively remote code execution
    # unless the service runs inside an isolated sandbox.
    exec_globals = {"df": df, "plt": plt, "pd": pd}

    # Start from a clean pyplot state so figures left over from elsewhere
    # cannot be drawn into or rendered for this request.
    plt.close("all")
    try:
        exec(code, exec_globals)
        img_base64 = _current_figure_as_base64_png()
        success = True
    except Exception as e:
        img_base64 = ""
        success = False
        code += f"\n\n# ERROR: {e}"
    finally:
        # Close every figure: the old plt.close() was skipped when exec()
        # raised, and it only closed the current figure, leaking any extra
        # figures the generated code created (e.g. plt.figure() + df.plot()).
        plt.close("all")

    return JSONResponse({"success": success, "plot": img_base64, "code": code})


def _current_figure_as_base64_png():
    """Render the current pyplot figure as PNG and return it base64-encoded."""
    buf = BytesIO()
    plt.savefig(buf, format="png")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
|
|
|