|
|
import asyncio
import base64
import os
from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
from fastapi import FastAPI, Request, File, UploadFile, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from google import genai
from google.genai import types
|
|
|
|
|
|
|
|
# The Gemini API key is read from the environment only. A literal key
# (AIza... prefix) used to be hard-coded here as the getenv fallback; treat it
# as compromised and revoke/rotate it -- removing it from the code is not enough.
# NOTE(review): when GEMINI_API_KEY is unset, API_KEY is None. Per the
# google-genai docs, genai.Client then does its own environment lookup and
# raises if no key is found, so startup fails loudly instead of silently
# using a shared key.
API_KEY = os.getenv("GEMINI_API_KEY")

# Gemini model that turns natural-language chart requests into plotting code.
MODEL = "gemini-2.5-flash-lite"

# Single client created at import time and shared by every request.
client = genai.Client(api_key=API_KEY)

app = FastAPI()

# Static assets from ./static and Jinja2 templates from ./templates. Both paths
# are resolved relative to the process's working directory, not this file.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
def get_metadata(df):
    """Summarise a DataFrame's structure for inclusion in an LLM prompt.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A dict with, in order, the keys ``columns`` (column labels),
        ``dtypes`` (label -> dtype name as a string), ``num_rows``,
        ``num_cols``, ``null_counts`` (label -> number of missing values),
        ``unique_counts`` (label -> ``DataFrame.nunique`` result) and
        ``sample_rows`` (the first three rows as a list of per-row dicts).
    """
    row_total, col_total = df.shape
    dtype_names = {label: str(kind) for label, kind in df.dtypes.items()}
    missing_per_col = df.isna().sum()
    distinct_per_col = df.nunique()
    preview = df.head(3)

    return dict(
        columns=[*df.columns],
        dtypes=dtype_names,
        num_rows=row_total,
        num_cols=col_total,
        null_counts=missing_per_col.to_dict(),
        unique_counts=distinct_per_col.to_dict(),
        sample_rows=preview.to_dict(orient="records"),
    )
|
|
|
|
|
def _strip_code_fences(text):
    """Return the body of the first Markdown code fence in *text*, or *text* stripped if it has none."""
    lines = text.strip().splitlines()
    # Any line starting with ``` counts as a fence, so ```python, ```py and a
    # bare ``` are all recognised.
    fence_rows = [i for i, line in enumerate(lines) if line.lstrip().startswith("```")]
    if not fence_rows:
        return text.strip()
    start = fence_rows[0] + 1
    # A missing closing fence (e.g. output cut off by max_output_tokens) keeps
    # everything after the opening fence.
    end = fence_rows[1] if len(fence_rows) > 1 else len(lines)
    return "\n".join(lines[start:end]).strip()


def generate_plot_code(user_query, metadata):
    """Ask Gemini for matplotlib/pandas code that draws the requested chart.

    Args:
        user_query: Free-text description of the desired plot.
        metadata: Dict as produced by ``get_metadata``; the keys ``columns``,
            ``dtypes``, ``null_counts``, ``unique_counts`` and ``sample_rows``
            are read.

    Returns:
        The generated Python source, with any surrounding Markdown code fence
        (and any prose outside it) removed. The prompt instructs the model to
        use an existing DataFrame named ``df``.

    SECURITY NOTE(review): ``user_query`` comes straight from an HTTP form and
    the returned code is later run with ``exec``. A crafted query can steer
    the model into emitting arbitrary Python (file access, subprocesses, ...).
    The prompt rules below are not a sandbox; run the generated code in an
    isolated process/container before exposing this publicly.
    """
    system_prompt = "\n".join([
        "You are a Python plotting assistant.",
        "Use the existing DataFrame named df.",
        "Do NOT load any files.",
        "Use only matplotlib or pandas plotting.",
        f"Use only the following columns: {metadata['columns']}.",
        "Do NOT explain, do NOT add extra text.",
        "Only produce executable code for plotting the requested chart.",
    ])
    user_prompt = "\n".join([
        "Dataset metadata:",
        f"Columns: {metadata['columns']}",
        f"Data types: {metadata['dtypes']}",
        f"Null counts: {metadata['null_counts']}",
        f"Unique counts: {metadata['unique_counts']}",
        f"Sample rows: {metadata['sample_rows']}",
        "",
        f"User request: {user_query}",
    ])

    contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
    config = types.GenerateContentConfig(
        temperature=0,  # minimise sampling randomness so similar requests give similar code
        max_output_tokens=1000,
        # Per the Gemini API docs, a thinking budget of 0 disables "thinking"
        # tokens on 2.5 Flash models, which lowers latency.
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        system_instruction=[types.Part.from_text(text=system_prompt)],
    )

    stream = client.models.generate_content_stream(model=MODEL, contents=contents, config=config)
    # A chunk's .text can be None when it carries no text parts (e.g. a
    # metadata-only or blocked chunk); `str += None` would raise TypeError.
    raw = "".join(chunk.text or "" for chunk in stream)
    return _strip_code_fences(raw)
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page from ``templates/index.html``."""
    # Request-first keyword form: the legacy TemplateResponse(name,
    # {"request": request}) call is deprecated in Starlette and warns. The
    # request is still added to the template context automatically.
    return templates.TemplateResponse(request=request, name="index.html")
|
|
|
|
|
def _plot_response(success, plot, code, status_code=200):
    """Build the JSON payload shape shared by every outcome of /generate_plot_file."""
    return JSONResponse({"success": success, "plot": plot, "code": code}, status_code=status_code)


def _render_plot(code, df):
    """Execute generated plotting *code* with *df* in scope and return the current figure as a base64 PNG string.

    Raises whatever the generated code or ``plt.savefig`` raises.
    """
    # SECURITY NOTE(review): `code` is LLM output steered by an untrusted user
    # query, and exec() runs it with full interpreter privileges. Restricting
    # the globals dict is NOT a sandbox.
    exec_globals = {"df": df, "plt": plt, "pd": pd}
    try:
        exec(code, exec_globals)
        buf = BytesIO()
        plt.savefig(buf, format="png")
    finally:
        # Close every figure, including ones left open when exec() raised, so
        # stale figures never accumulate or bleed into the next request.
        plt.close("all")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


@app.post("/generate_plot_file")
async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
    """Read an uploaded Excel file, ask Gemini for plotting code, run it and return the PNG.

    Form fields:
        file: Excel workbook, parsed with ``pd.read_excel`` default arguments.
        query: Natural-language description of the desired chart.

    Returns:
        JSON ``{"success": bool, "plot": str, "code": str}``. ``plot`` is a
        base64-encoded PNG (empty on failure). ``code`` is the generated source
        on success or when that code raised (a trailing ``# ERROR: ...`` comment
        is appended in the latter case); otherwise it is an ``# ERROR: ...``
        comment. Status 400 if the upload cannot be parsed, 502 if code
        generation fails, 200 otherwise (including when the generated code
        itself raises).
    """
    try:
        # read_excel is blocking; run it in a worker thread so the event loop stays responsive.
        df = await asyncio.to_thread(pd.read_excel, file.file)
    except Exception as e:
        # Unparseable upload is a client error, not a server crash.
        return _plot_response(False, "", f"# ERROR: could not read Excel file: {e}", status_code=400)

    metadata = get_metadata(df)

    try:
        # generate_plot_code does synchronous network I/O; offload it for the same reason.
        code = await asyncio.to_thread(generate_plot_code, query, metadata)
    except Exception as e:
        return _plot_response(False, "", f"# ERROR: plot code generation failed: {e}", status_code=502)

    # Deliberately run on the event-loop thread with no await in between:
    # pyplot keeps global figure state, so keeping exec + savefig + close
    # atomic prevents concurrent requests from drawing into each other's figure.
    try:
        img_base64 = _render_plot(code, df)
    except Exception as e:
        return _plot_response(False, "", f"{code}\n\n# ERROR: {e}")
    return _plot_response(True, img_base64, code)
|
|
|