triflix commited on
Commit
f57254f
·
verified ·
1 Parent(s): b1aea0a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +100 -100
main.py CHANGED
@@ -1,100 +1,100 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
- from io import BytesIO
6
- import base64
7
- import matplotlib.pyplot as plt
8
- import pandas as pd
9
- from google import genai
10
- from google.genai import types
11
- import os
12
-
13
- # ---- User configuration ----
14
- API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_API_KEY")
15
- EXCEL_FILE = "/content/heart_failure_clinical_records_dashboard.xlsx"
16
-
17
- client = genai.Client(api_key=API_KEY)
18
- MODEL = "gemini-2.5-flash-lite"
19
-
20
- # FastAPI setup
21
- app = FastAPI()
22
- app.mount("/static", StaticFiles(directory="static"), name="static")
23
- templates = Jinja2Templates(directory="templates")
24
-
25
- # Load dataset
26
- df = pd.read_excel(EXCEL_FILE)
27
-
28
- def get_metadata(df):
29
- return {
30
- "columns": list(df.columns),
31
- "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
32
- "num_rows": df.shape[0],
33
- "num_cols": df.shape[1],
34
- "null_counts": df.isnull().sum().to_dict(),
35
- "unique_counts": df.nunique().to_dict(),
36
- "sample_rows": df.head(3).to_dict(orient="records")
37
- }
38
-
39
- def generate_plot_code(user_query, metadata):
40
- system_prompt = f"""
41
- You are a Python plotting assistant.
42
- Use the existing DataFrame named df.
43
- Do NOT load any files.
44
- Use only matplotlib or pandas plotting.
45
- Use only the following columns: {metadata['columns']}.
46
- Do NOT explain, do NOT add extra text.
47
- Only produce executable code for plotting the requested chart.
48
- """
49
- user_prompt = f"""
50
- Dataset metadata:
51
- Columns: {metadata['columns']}
52
- Data types: {metadata['dtypes']}
53
- Null counts: {metadata['null_counts']}
54
- Unique counts: {metadata['unique_counts']}
55
- Sample rows: {metadata['sample_rows']}
56
-
57
- User request: {user_query}
58
- """
59
- contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
60
- config = types.GenerateContentConfig(
61
- temperature=0,
62
- max_output_tokens=1000,
63
- thinking_config=types.ThinkingConfig(thinking_budget=0),
64
- system_instruction=[types.Part.from_text(text=system_prompt)]
65
- )
66
-
67
- code = ""
68
- for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
69
- code += chunk.text
70
- return code.replace("```python", "").replace("```", "").strip()
71
-
72
- @app.get("/", response_class=HTMLResponse)
73
- async def home(request: Request):
74
- return templates.TemplateResponse("index.html", {"request": request})
75
-
76
- @app.post("/generate_plot")
77
- async def generate_plot(request: Request):
78
- data = await request.json()
79
- user_query = data.get("query", "")
80
- metadata = get_metadata(df)
81
-
82
- # Generate code
83
- code = generate_plot_code(user_query, metadata)
84
-
85
- # Execute code and generate plot
86
- try:
87
- exec_globals = {"df": df, "plt": plt}
88
- exec(code, exec_globals)
89
- buf = BytesIO()
90
- plt.savefig(buf, format="png")
91
- plt.close()
92
- buf.seek(0)
93
- img_base64 = base64.b64encode(buf.read()).decode("utf-8")
94
- success = True
95
- except Exception as e:
96
- img_base64 = ""
97
- success = False
98
- code += f"\n\n# ERROR: {str(e)}"
99
-
100
- return JSONResponse({"success": success, "plot": img_base64, "code": code})
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ from io import BytesIO
6
+ import base64
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
+ from google import genai
10
+ from google.genai import types
11
+ import os
12
+
13
+ # ---- User configuration ----
14
+ API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs")
15
+ EXCEL_FILE = "/content/heart_failure_clinical_records_dashboard.xlsx"
16
+
17
+ client = genai.Client(api_key=API_KEY)
18
+ MODEL = "gemini-2.5-flash-lite"
19
+
20
+ # FastAPI setup
21
+ app = FastAPI()
22
+ app.mount("/static", StaticFiles(directory="static"), name="static")
23
+ templates = Jinja2Templates(directory="templates")
24
+
25
+ # Load dataset
26
+ df = pd.read_excel(EXCEL_FILE)
27
+
28
+ def get_metadata(df):
29
+ return {
30
+ "columns": list(df.columns),
31
+ "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
32
+ "num_rows": df.shape[0],
33
+ "num_cols": df.shape[1],
34
+ "null_counts": df.isnull().sum().to_dict(),
35
+ "unique_counts": df.nunique().to_dict(),
36
+ "sample_rows": df.head(3).to_dict(orient="records")
37
+ }
38
+
39
+ def generate_plot_code(user_query, metadata):
40
+ system_prompt = f"""
41
+ You are a Python plotting assistant.
42
+ Use the existing DataFrame named df.
43
+ Do NOT load any files.
44
+ Use only matplotlib or pandas plotting.
45
+ Use only the following columns: {metadata['columns']}.
46
+ Do NOT explain, do NOT add extra text.
47
+ Only produce executable code for plotting the requested chart.
48
+ """
49
+ user_prompt = f"""
50
+ Dataset metadata:
51
+ Columns: {metadata['columns']}
52
+ Data types: {metadata['dtypes']}
53
+ Null counts: {metadata['null_counts']}
54
+ Unique counts: {metadata['unique_counts']}
55
+ Sample rows: {metadata['sample_rows']}
56
+
57
+ User request: {user_query}
58
+ """
59
+ contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
60
+ config = types.GenerateContentConfig(
61
+ temperature=0,
62
+ max_output_tokens=1000,
63
+ thinking_config=types.ThinkingConfig(thinking_budget=0),
64
+ system_instruction=[types.Part.from_text(text=system_prompt)]
65
+ )
66
+
67
+ code = ""
68
+ for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
69
+ code += chunk.text
70
+ return code.replace("```python", "").replace("```", "").strip()
71
+
72
+ @app.get("/", response_class=HTMLResponse)
73
+ async def home(request: Request):
74
+ return templates.TemplateResponse("index.html", {"request": request})
75
+
76
+ @app.post("/generate_plot")
77
+ async def generate_plot(request: Request):
78
+ data = await request.json()
79
+ user_query = data.get("query", "")
80
+ metadata = get_metadata(df)
81
+
82
+ # Generate code
83
+ code = generate_plot_code(user_query, metadata)
84
+
85
+ # Execute code and generate plot
86
+ try:
87
+ exec_globals = {"df": df, "plt": plt}
88
+ exec(code, exec_globals)
89
+ buf = BytesIO()
90
+ plt.savefig(buf, format="png")
91
+ plt.close()
92
+ buf.seek(0)
93
+ img_base64 = base64.b64encode(buf.read()).decode("utf-8")
94
+ success = True
95
+ except Exception as e:
96
+ img_base64 = ""
97
+ success = False
98
+ code += f"\n\n# ERROR: {str(e)}"
99
+
100
+ return JSONResponse({"success": success, "plot": img_base64, "code": code})