triflix commited on
Commit
98ea4f9
·
verified ·
1 Parent(s): f57254f

Upload 7 files

Browse files
Files changed (3) hide show
  1. main.py +96 -100
  2. requirements.txt +2 -1
  3. templates/index.html +11 -4
main.py CHANGED
@@ -1,100 +1,96 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
- from io import BytesIO
6
- import base64
7
- import matplotlib.pyplot as plt
8
- import pandas as pd
9
- from google import genai
10
- from google.genai import types
11
- import os
12
-
13
- # ---- User configuration ----
14
- API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyB1jgGCuzg7ELPwNEEwaluQZoZhxhgLmAs")
15
- EXCEL_FILE = "/content/heart_failure_clinical_records_dashboard.xlsx"
16
-
17
- client = genai.Client(api_key=API_KEY)
18
- MODEL = "gemini-2.5-flash-lite"
19
-
20
- # FastAPI setup
21
- app = FastAPI()
22
- app.mount("/static", StaticFiles(directory="static"), name="static")
23
- templates = Jinja2Templates(directory="templates")
24
-
25
- # Load dataset
26
- df = pd.read_excel(EXCEL_FILE)
27
-
28
- def get_metadata(df):
29
- return {
30
- "columns": list(df.columns),
31
- "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
32
- "num_rows": df.shape[0],
33
- "num_cols": df.shape[1],
34
- "null_counts": df.isnull().sum().to_dict(),
35
- "unique_counts": df.nunique().to_dict(),
36
- "sample_rows": df.head(3).to_dict(orient="records")
37
- }
38
-
39
- def generate_plot_code(user_query, metadata):
40
- system_prompt = f"""
41
- You are a Python plotting assistant.
42
- Use the existing DataFrame named df.
43
- Do NOT load any files.
44
- Use only matplotlib or pandas plotting.
45
- Use only the following columns: {metadata['columns']}.
46
- Do NOT explain, do NOT add extra text.
47
- Only produce executable code for plotting the requested chart.
48
- """
49
- user_prompt = f"""
50
- Dataset metadata:
51
- Columns: {metadata['columns']}
52
- Data types: {metadata['dtypes']}
53
- Null counts: {metadata['null_counts']}
54
- Unique counts: {metadata['unique_counts']}
55
- Sample rows: {metadata['sample_rows']}
56
-
57
- User request: {user_query}
58
- """
59
- contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
60
- config = types.GenerateContentConfig(
61
- temperature=0,
62
- max_output_tokens=1000,
63
- thinking_config=types.ThinkingConfig(thinking_budget=0),
64
- system_instruction=[types.Part.from_text(text=system_prompt)]
65
- )
66
-
67
- code = ""
68
- for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
69
- code += chunk.text
70
- return code.replace("```python", "").replace("```", "").strip()
71
-
72
- @app.get("/", response_class=HTMLResponse)
73
- async def home(request: Request):
74
- return templates.TemplateResponse("index.html", {"request": request})
75
-
76
- @app.post("/generate_plot")
77
- async def generate_plot(request: Request):
78
- data = await request.json()
79
- user_query = data.get("query", "")
80
- metadata = get_metadata(df)
81
-
82
- # Generate code
83
- code = generate_plot_code(user_query, metadata)
84
-
85
- # Execute code and generate plot
86
- try:
87
- exec_globals = {"df": df, "plt": plt}
88
- exec(code, exec_globals)
89
- buf = BytesIO()
90
- plt.savefig(buf, format="png")
91
- plt.close()
92
- buf.seek(0)
93
- img_base64 = base64.b64encode(buf.read()).decode("utf-8")
94
- success = True
95
- except Exception as e:
96
- img_base64 = ""
97
- success = False
98
- code += f"\n\n# ERROR: {str(e)}"
99
-
100
- return JSONResponse({"success": success, "plot": img_base64, "code": code})
 
1
+ from fastapi import FastAPI, Request, File, UploadFile, Form
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ from io import BytesIO
6
+ import base64
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
+ from google import genai
10
+ from google.genai import types
11
+ import os
12
+
13
+ # ---- Configuration ----
14
+ API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_API_KEY")
15
+ MODEL = "gemini-2.5-flash-lite"
16
+
17
+ client = genai.Client(api_key=API_KEY)
18
+
19
+ # FastAPI setup
20
+ app = FastAPI()
21
+ app.mount("/static", StaticFiles(directory="static"), name="static")
22
+ templates = Jinja2Templates(directory="templates")
23
+
24
+ def get_metadata(df):
25
+ return {
26
+ "columns": list(df.columns),
27
+ "dtypes": df.dtypes.apply(lambda x: str(x)).to_dict(),
28
+ "num_rows": df.shape[0],
29
+ "num_cols": df.shape[1],
30
+ "null_counts": df.isnull().sum().to_dict(),
31
+ "unique_counts": df.nunique().to_dict(),
32
+ "sample_rows": df.head(3).to_dict(orient="records")
33
+ }
34
+
35
+ def generate_plot_code(user_query, metadata):
36
+ system_prompt = f"""
37
+ You are a Python plotting assistant.
38
+ Use the existing DataFrame named df.
39
+ Do NOT load any files.
40
+ Use only matplotlib or pandas plotting.
41
+ Use only the following columns: {metadata['columns']}.
42
+ Do NOT explain, do NOT add extra text.
43
+ Only produce executable code for plotting the requested chart.
44
+ """
45
+ user_prompt = f"""
46
+ Dataset metadata:
47
+ Columns: {metadata['columns']}
48
+ Data types: {metadata['dtypes']}
49
+ Null counts: {metadata['null_counts']}
50
+ Unique counts: {metadata['unique_counts']}
51
+ Sample rows: {metadata['sample_rows']}
52
+
53
+ User request: {user_query}
54
+ """
55
+ contents = [types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])]
56
+ config = types.GenerateContentConfig(
57
+ temperature=0,
58
+ max_output_tokens=1000,
59
+ thinking_config=types.ThinkingConfig(thinking_budget=0),
60
+ system_instruction=[types.Part.from_text(text=system_prompt)]
61
+ )
62
+
63
+ code = ""
64
+ for chunk in client.models.generate_content_stream(model=MODEL, contents=contents, config=config):
65
+ code += chunk.text
66
+ return code.replace("```python", "").replace("```", "").strip()
67
+
68
+ @app.get("/", response_class=HTMLResponse)
69
+ async def home(request: Request):
70
+ return templates.TemplateResponse("index.html", {"request": request})
71
+
72
+ @app.post("/generate_plot_file")
73
+ async def generate_plot_file(file: UploadFile = File(...), query: str = Form(...)):
74
+ # Read uploaded Excel
75
+ df = pd.read_excel(file.file)
76
+ metadata = get_metadata(df)
77
+
78
+ # Generate AI plotting code
79
+ code = generate_plot_code(query, metadata)
80
+
81
+ # Execute code
82
+ try:
83
+ exec_globals = {"df": df, "plt": plt}
84
+ exec(code, exec_globals)
85
+ buf = BytesIO()
86
+ plt.savefig(buf, format="png")
87
+ plt.close()
88
+ buf.seek(0)
89
+ img_base64 = base64.b64encode(buf.read()).decode("utf-8")
90
+ success = True
91
+ except Exception as e:
92
+ img_base64 = ""
93
+ success = False
94
+ code += f"\n\n# ERROR: {str(e)}"
95
+
96
+ return JSONResponse({"success": success, "plot": img_base64, "code": code})
 
 
 
 
requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn
3
  jinja2
4
  pandas
5
  matplotlib
6
- google-genai
 
 
3
  jinja2
4
  pandas
5
  matplotlib
6
+ google-genai
7
+ python-multipart
templates/index.html CHANGED
@@ -11,6 +11,7 @@
11
  <h1 class="text-2xl font-bold mb-4">AI Plotting App</h1>
12
 
13
  <div class="w-full max-w-2xl bg-white p-4 rounded shadow mb-4">
 
14
  <input id="userQuery" type="text" placeholder="Enter plot request..." class="w-full border p-2 rounded mb-2">
15
  <button onclick="generatePlot()" class="bg-blue-600 text-white px-4 py-2 rounded">Generate Plot</button>
16
  </div>
@@ -27,11 +28,17 @@
27
 
28
  <script>
29
  async function generatePlot() {
30
- const query = document.getElementById("userQuery").value;
31
- const response = await fetch("/generate_plot", {
 
 
 
 
 
 
 
32
  method: "POST",
33
- headers: {"Content-Type": "application/json"},
34
- body: JSON.stringify({query})
35
  });
36
  const data = await response.json();
37
  if(data.success){
 
11
  <h1 class="text-2xl font-bold mb-4">AI Plotting App</h1>
12
 
13
  <div class="w-full max-w-2xl bg-white p-4 rounded shadow mb-4">
14
+ <input id="excelFile" type="file" accept=".xlsx" class="w-full mb-2" />
15
  <input id="userQuery" type="text" placeholder="Enter plot request..." class="w-full border p-2 rounded mb-2">
16
  <button onclick="generatePlot()" class="bg-blue-600 text-white px-4 py-2 rounded">Generate Plot</button>
17
  </div>
 
28
 
29
  <script>
30
  async function generatePlot() {
31
+ const fileInput = document.getElementById("excelFile");
32
+ const queryInput = document.getElementById("userQuery");
33
+ if (!fileInput.files.length) { alert("Please upload a file."); return; }
34
+
35
+ const formData = new FormData();
36
+ formData.append("file", fileInput.files[0]);
37
+ formData.append("query", queryInput.value);
38
+
39
+ const response = await fetch("/generate_plot_file", {
40
  method: "POST",
41
+ body: formData
 
42
  });
43
  const data = await response.json();
44
  if(data.success){