triflix commited on
Commit
585b9ad
·
verified ·
1 Parent(s): d9056b8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +96 -96
main.py CHANGED
@@ -1,96 +1,96 @@
1
- from fastapi import FastAPI, Request, File, UploadFile, Form
2
- from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
- from io import BytesIO
6
- import base64
7
- import matplotlib.pyplot as plt
8
- import pandas as pd
9
- from google import genai
10
- from google.genai import types
11
- import os
12
-
13
- # ---- Configuration ----
14
- API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_API_KEY")
15
- MODEL = "gemini-2.5-flash-lite"
16
-
17
- client = genai.Client(api_key=API_KEY)
18
-
19
- # FastAPI setup
20
- app = FastAPI()
21
- app.mount("/static", StaticFiles(directory="static"), name="static")
22
- templates = Jinja2Templates(directory="templates")
23
-
24
- def get_metadata(df):
25
- return {
26
- "columns": list(df.columns),
27
- "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
28
- "num_rows": df.shape[0],
29
- "num_cols": df.shape[1],
30
- "null_counts": df.isnull().sum().to_dict(),
31
- "unique_counts": df.nunique().to_dict(),
32
- "sample_rows": df.head(3).to_dict(orient="records")
33
- }
34
-
35
- def generate_plot_code(user_query, metadata):
36
- system_prompt = f"""
37
- You are a Python plotting assistant.
38
- Use the existing DataFrame named df.
39
- Do NOT load any files.
40
- Use only matplotlib or pandas plotting.
41
- Use only the following columns: {metadata['columns']}.
42
- Do NOT explain, do NOT add extra text.
43
- Only produce executable code for plotting the requested chart.
44
- """
45
- user_prompt = f"""
46
- Dataset metadata:
47
- Columns: {metadata['columns']}
48
- Data types: {metadata['dtypes']}
49
- Null counts: {metadata['null_counts']}
50
- Unique counts: {metadata['unique_counts']}
51
- Sample rows: {metadata['sample_rows']}
52
-
53
- User request: {user_query}
54
- """
55
- contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
56
- config = types.GenerateContentConfig(
57
- temperature=0,
58
- max_output_tokens=1000,
59
- thinking_config=types.ThinkingConfig(thinking_budget=0),
60
- system_instruction=[types.Part.from_text(text=system_prompt)]
61
- )
62
-
63
- code = ""
64
- for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
65
- code += chunk.text
66
- return code.replace("```python", "").replace("```", "").strip()
67
-
68
- @app.get("/", response_class=HTMLResponse)
69
- async def home(request: Request):
70
- return templates.TemplateResponse("index.html", {"request": request})
71
-
72
- @app.post("/generate_plot_file")
73
- async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
74
- # Read uploaded Excel
75
- df = pd.read_excel(file.file)
76
- metadata = get_metadata(df)
77
-
78
- # Generate AI plotting code
79
- code = generate_plot_code(query, metadata)
80
-
81
- # Execute code
82
- try:
83
- exec_globals = {"df": df, "plt": plt}
84
- exec(code, exec_globals)
85
- buf = BytesIO()
86
- plt.savefig(buf, format="png")
87
- plt.close()
88
- buf.seek(0)
89
- img_base64 = base64.b64encode(buf.read()).decode("utf-8")
90
- success = True
91
- except Exception as e:
92
- img_base64 = ""
93
- success = False
94
- code += f"\n\n# ERROR: {str(e)}"
95
-
96
- return JSONResponse({"success": success, "plot": img_base64, "code": code})
 
1
+ from fastapi import FastAPI, Request, File, UploadFile, Form
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ from io import BytesIO
6
+ import base64
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
+ from google import genai
10
+ from google.genai import types
11
+ import os
12
+
13
+ # ---- Configuration ----
14
+ API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs")
15
+ MODEL = "gemini-2.5-flash-lite"
16
+
17
+ client = genai.Client(api_key=API_KEY)
18
+
19
+ # FastAPI setup
20
+ app = FastAPI()
21
+ app.mount("/static", StaticFiles(directory="static"), name="static")
22
+ templates = Jinja2Templates(directory="templates")
23
+
24
+ def get_metadata(df):
25
+ return {
26
+ "columns": list(df.columns),
27
+ "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
28
+ "num_rows": df.shape[0],
29
+ "num_cols": df.shape[1],
30
+ "null_counts": df.isnull().sum().to_dict(),
31
+ "unique_counts": df.nunique().to_dict(),
32
+ "sample_rows": df.head(3).to_dict(orient="records")
33
+ }
34
+
35
+ def generate_plot_code(user_query, metadata):
36
+ system_prompt = f"""
37
+ You are a Python plotting assistant.
38
+ Use the existing DataFrame named df.
39
+ Do NOT load any files.
40
+ Use only matplotlib or pandas plotting.
41
+ Use only the following columns: {metadata['columns']}.
42
+ Do NOT explain, do NOT add extra text.
43
+ Only produce executable code for plotting the requested chart.
44
+ """
45
+ user_prompt = f"""
46
+ Dataset metadata:
47
+ Columns: {metadata['columns']}
48
+ Data types: {metadata['dtypes']}
49
+ Null counts: {metadata['null_counts']}
50
+ Unique counts: {metadata['unique_counts']}
51
+ Sample rows: {metadata['sample_rows']}
52
+
53
+ User request: {user_query}
54
+ """
55
+ contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
56
+ config = types.GenerateContentConfig(
57
+ temperature=0,
58
+ max_output_tokens=1000,
59
+ thinking_config=types.ThinkingConfig(thinking_budget=0),
60
+ system_instruction=[types.Part.from_text(text=system_prompt)]
61
+ )
62
+
63
+ code = ""
64
+ for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
65
+ code += chunk.text
66
+ return code.replace("```python", "").replace("```", "").strip()
67
+
68
+ @app.get("/", response_class=HTMLResponse)
69
+ async def home(request: Request):
70
+ return templates.TemplateResponse("index.html", {"request": request})
71
+
72
+ @app.post("/generate_plot_file")
73
+ async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
74
+ # Read uploaded Excel
75
+ df = pd.read_excel(file.file)
76
+ metadata = get_metadata(df)
77
+
78
+ # Generate AI plotting code
79
+ code = generate_plot_code(query, metadata)
80
+
81
+ # Execute code
82
+ try:
83
+ exec_globals = {"df": df, "plt": plt}
84
+ exec(code, exec_globals)
85
+ buf = BytesIO()
86
+ plt.savefig(buf, format="png")
87
+ plt.close()
88
+ buf.seek(0)
89
+ img_base64 = base64.b64encode(buf.read()).decode("utf-8")
90
+ success = True
91
+ except Exception as e:
92
+ img_base64 = ""
93
+ success = False
94
+ code += f"\n\n# ERROR: {str(e)}"
95
+
96
+ return JSONResponse({"success": success, "plot": img_base64, "code": code})