File size: 3,178 Bytes
585b9ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import asyncio
import base64
import os
import re
from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
from fastapi import FastAPI, Request, File, UploadFile, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from google import genai
from google.genai import types
# ---- Configuration ----
# The Gemini API key must come from the environment. It is deliberately NOT
# given a hard-coded fallback: a key committed to source control is leaked and
# should be revoked and rotated.
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    # Fail at startup with a clear message instead of on the first request.
    raise RuntimeError("GEMINI_API_KEY environment variable is not set")
MODEL = "gemini-2.5-flash-lite"
client = genai.Client(api_key=API_KEY)

# FastAPI setup: static assets are served from ./static, HTML templates are
# loaded from ./templates.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
def get_metadata(df):
    """Summarise a DataFrame so it can be described to the LLM.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A dict with the keys, in this order:
        ``columns`` (list of column labels), ``dtypes`` (column -> dtype name as
        a string), ``num_rows``, ``num_cols``, ``null_counts`` (column -> number
        of missing values), ``unique_counts`` (column -> number of distinct
        non-null values) and ``sample_rows`` (the first three rows as a list of
        record dicts).
    """
    row_count, col_count = df.shape
    dtype_names = {name: str(dtype) for name, dtype in df.dtypes.items()}
    preview = df.head(3).to_dict(orient="records")
    return {
        "columns": list(df.columns),
        "dtypes": dtype_names,
        "num_rows": row_count,
        "num_cols": col_count,
        "null_counts": df.isnull().sum().to_dict(),
        "unique_counts": df.nunique().to_dict(),
        "sample_rows": preview,
    }
def _strip_code_fences(text):
    """Remove Markdown code-fence markers (```, ```python, ```py, ...) and trim whitespace."""
    # Matching only the literal "```python" left a stray "py"/"Python" line
    # behind whenever the model used another fence tag; that line then failed
    # at exec time with a NameError.
    return re.sub(r"```[A-Za-z0-9_+-]*", "", text).strip()


def generate_plot_code(user_query, metadata):
    """Ask the Gemini model for Python code that draws the requested chart.

    Args:
        user_query: Free-text chart request from the end user.
        metadata: Dict as returned by ``get_metadata`` for the DataFrame the
            code will later run against. The keys ``columns``, ``dtypes``,
            ``null_counts``, ``unique_counts`` and ``sample_rows`` are embedded
            in the prompt.

    Returns:
        The model's streamed response text, concatenated, with Markdown code
        fences removed and surrounding whitespace stripped. The text is not
        validated in any way; callers must treat it as untrusted.
    """
    system_prompt = f"""
You are a Python plotting assistant.
Use the existing DataFrame named df.
Do NOT load any files.
Use only matplotlib or pandas plotting.
Use only the following columns: {metadata['columns']}.
Do NOT explain, do NOT add extra text.
Only produce executable code for plotting the requested chart.
"""
    user_prompt = f"""
Dataset metadata:
Columns: {metadata['columns']}
Data types: {metadata['dtypes']}
Null counts: {metadata['null_counts']}
Unique counts: {metadata['unique_counts']}
Sample rows: {metadata['sample_rows']}
User request: {user_query}
"""
    contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
    config = types.GenerateContentConfig(
        temperature=0,  # deterministic output for the same request
        max_output_tokens=1000,
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        system_instruction=[types.Part.from_text(text=system_prompt)],
    )
    pieces = []
    for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
        # ``chunk.text`` is None for chunks that carry no text part (e.g. one
        # holding only finish/usage metadata); ``str += None`` raised TypeError.
        text = chunk.text
        if text:
            pieces.append(text)
    return _strip_code_fences("".join(pieces))
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    # Starlette's TemplateResponse(name, context) form requires the request
    # object to be present in the template context.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
def _render_plot_png_base64(code, df):
    """Execute generated plotting code against ``df`` and return the figure as base64 PNG.

    The code runs with ``df``, ``plt`` and ``pd`` pre-bound as globals. Any
    exception raised by the code, by the "no figure" check, or by saving the
    figure propagates to the caller. All pyplot figures are closed afterwards,
    on success and on failure.

    SECURITY: ``code`` is LLM output steered by an end-user prompt, i.e.
    untrusted input. ``exec`` gives it full access to the server process
    (files, network, environment variables such as GEMINI_API_KEY), and since
    it runs synchronously it can also block the server indefinitely.
    Restricting globals is NOT a sandbox; run it in an isolated, time-limited
    process/container before exposing this endpoint to untrusted users.
    """
    try:
        # ``pd`` is provided so generated code that uses pandas helpers without
        # importing pandas itself still runs.
        exec(code, {"df": df, "plt": plt, "pd": pd})
        if not plt.get_fignums():
            # Otherwise savefig would silently create and return a blank image.
            raise ValueError("generated code did not create a figure")
        buf = BytesIO()
        plt.savefig(buf, format="png")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    finally:
        # Close every open figure, including those left behind when exec or
        # savefig raised; previously they accumulated across requests.
        plt.close("all")


@app.post("/generate_plot_file")
async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
    """Generate and render a chart for an uploaded Excel file.

    Form fields:
        file: Excel workbook, loaded with ``pd.read_excel`` (first sheet).
        query: Natural-language description of the chart to draw.

    Returns:
        JSON ``{"success": bool, "plot": str, "code": str}``. ``plot`` is a
        base64-encoded PNG ("" on failure). ``code`` is the generated code,
        with an ``# ERROR: ...`` line appended if running it failed. An
        unreadable upload yields the same shape with HTTP status 400.
    """
    # Excel parsing and the Gemini HTTP call are blocking; run them in worker
    # threads so they do not stall the event loop for other requests.
    try:
        df = await asyncio.to_thread(pd.read_excel, file.file)
    except Exception as e:
        return JSONResponse(
            {"success": False, "plot": "", "code": f"# ERROR: could not read Excel file: {e}"},
            status_code=400,
        )
    metadata = get_metadata(df)
    code = await asyncio.to_thread(generate_plot_code, query, metadata)

    # Plotting deliberately stays on the event-loop thread: pyplot figure state
    # is global, and with no ``await`` between exec and close, concurrent
    # requests cannot interleave their figures.
    try:
        img_base64 = _render_plot_png_base64(code, df)
        success = True
    except Exception as e:
        img_base64 = ""
        success = False
        code += f"\n\n# ERROR: {e}"
    return JSONResponse({"success": success, "plot": img_base64, "code": code})
|