"""FastAPI app that turns a natural-language request into a matplotlib chart.

A user uploads an Excel file plus a free-text query. The app summarises the
sheet's schema, asks a Gemini model to write plotting code for it, executes
that code against the loaded DataFrame and returns the rendered PNG (base64)
together with the generated code.
"""

import base64
import os
import re
import threading
from io import BytesIO

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from google import genai
from google.genai import types

# Render off-screen: the server has no display, and GUI backends cannot be
# driven from worker threads. Switching to a non-interactive backend is
# allowed even after pyplot has been imported.
matplotlib.use("Agg")

# ---- Configuration ----
# The key must come from the environment. A real key used to be hard-coded
# here as a fallback; treat that key as leaked and rotate it.
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise RuntimeError("GEMINI_API_KEY environment variable is not set")

MODEL = "gemini-2.5-flash-lite"
client = genai.Client(api_key=API_KEY)

# pyplot keeps a process-wide "current figure"; serialise all plotting so
# concurrent requests (handled in FastAPI's threadpool) cannot draw onto or
# save each other's figures.
_PLOT_LOCK = threading.Lock()

# Matches Markdown code fences with an optional language tag (```, ```python,
# ```py, ...), which the model sometimes wraps its answer in.
_FENCE_RE = re.compile(r"```[A-Za-z0-9_+-]*")

# FastAPI setup
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


def get_metadata(df: pd.DataFrame) -> dict:
    """Summarise *df*'s schema for inclusion in the LLM prompt.

    Returns a dict with the column names, per-column dtype strings, the
    row/column counts, per-column null and unique counts, and the first three
    rows as a list of records.
    """
    return {
        "columns": list(df.columns),
        "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
        "num_rows": df.shape[0],
        "num_cols": df.shape[1],
        "null_counts": df.isnull().sum().to_dict(),
        "unique_counts": df.nunique().to_dict(),
        "sample_rows": df.head(3).to_dict(orient="records"),
    }


def generate_plot_code(user_query: str, metadata: dict) -> str:
    """Ask the Gemini model for matplotlib/pandas code that draws *user_query*.

    Args:
        user_query: The user's free-text chart request.
        metadata: Schema summary as produced by ``get_metadata``.

    Returns:
        The generated Python source with Markdown code fences removed.
    """
    system_prompt = f"""
You are a Python plotting assistant.
Use the existing DataFrame named df.
Do NOT load any files.
Use only matplotlib or pandas plotting.
Use only the following columns: {metadata['columns']}.
Do NOT explain, do NOT add extra text.
Only produce executable code for plotting the requested chart.
"""
    user_prompt = f"""
Dataset metadata:
Columns: {metadata['columns']}
Data types: {metadata['dtypes']}
Null counts: {metadata['null_counts']}
Unique counts: {metadata['unique_counts']}
Sample rows: {metadata['sample_rows']}

User request: {user_query}
"""
    contents = [
        types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])
    ]
    config = types.GenerateContentConfig(
        temperature=0,
        max_output_tokens=1000,
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        system_instruction=[types.Part.from_text(text=system_prompt)],
    )

    parts = []
    for chunk in client.models.generate_content_stream(
        model=MODEL, contents=contents, config=config
    ):
        # A chunk without text parts reports ``text`` as None; appending it
        # directly would raise TypeError.
        if chunk.text:
            parts.append(chunk.text)
    return _FENCE_RE.sub("", "".join(parts)).strip()


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI from ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})


# Declared with plain ``def`` (not ``async def``): every step below blocks
# (Excel parsing, the LLM call, exec, rendering), so FastAPI should run it in
# its threadpool instead of stalling the event loop.
@app.post("/generate_plot_file")
def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
    """Plot an uploaded Excel sheet according to a natural-language *query*.

    Responds with JSON ``{"success": bool, "plot": str, "code": str}``, where
    ``plot`` is a base64-encoded PNG (empty on failure) and ``code`` is the
    generated source. On failure, ``code`` has an ``# ERROR: ...`` line
    appended. An unreadable upload yields the same shape with status 400.
    """
    try:
        df = pd.read_excel(file.file)
    except Exception as e:  # pandas/engines raise many types for bad workbooks
        return JSONResponse(
            {
                "success": False,
                "plot": "",
                "code": f"# ERROR: could not read Excel file: {e}",
            },
            status_code=400,
        )

    metadata = get_metadata(df)
    code = generate_plot_code(query, metadata)

    # SECURITY: ``code`` is written by an LLM from user-controlled text and is
    # executed with full interpreter privileges. A crafted query can make the
    # model emit arbitrary Python (file/network access, reading
    # GEMINI_API_KEY, ...). Only deploy this inside an isolated sandbox.
    img_base64 = ""
    success = False
    with _PLOT_LOCK:
        plt.close("all")  # never draw onto figures left over from earlier runs
        try:
            exec(code, {"df": df, "plt": plt, "pd": pd})
            if not plt.get_fignums():
                # Otherwise savefig would silently render a blank canvas.
                raise ValueError("generated code did not create a figure")
            buf = BytesIO()
            plt.savefig(buf, format="png")
            img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
            success = True
        except Exception as e:
            code += f"\n\n# ERROR: {e}"
        finally:
            # Free figures on both the success and the failure path.
            plt.close("all")

    return JSONResponse({"success": success, "plot": img_base64, "code": code})