Osnly committed on
Commit
583755a
·
verified ·
1 Parent(s): 49436c8

Update src/visual_insight.py

Browse files
Files changed (1) hide show
  1. src/visual_insight.py +55 -52
src/visual_insight.py CHANGED
@@ -1,52 +1,55 @@
1
- # visual_insight.py
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
- import json
5
- import re
6
-
7
- model_id = "google/gemma-3n-E4B-it"
8
- tokenizer = AutoTokenizer.from_pretrained(model_id)
9
- model = AutoModelForCausalLM.from_pretrained(model_id)
10
-
11
- def call_llm(prompt):
12
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
13
- outputs = model.generate(**inputs, max_new_tokens=2048)
14
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
15
-
16
- visual_prompt = """
17
- You are a data visualization expert. You will be given a summary of a cleaned dataset.
18
-
19
- Your tasks:
20
- 1. Suggest 3–5 interesting visualizations that would help uncover patterns or relationships.
21
- 2. For each, describe what insight it may reveal.
22
- 3. For each, write Python code using pandas/seaborn/matplotlib to generate the plot. Use 'df' as the dataframe and be precise with column names.
23
- 4. Always be careful and precise with column names
24
- Output JSON in this exact format:
25
- {
26
- "visualizations": [
27
- {
28
- "title": "Histogram of Age",
29
- "description": "Shows the distribution of age",
30
- "code": "sns.histplot(df['age'], kde=True); plt.title('Age Distribution'); plt.savefig('charts/age.png'); plt.clf()"
31
- },
32
- ...
33
- ]
34
- }
35
-
36
- Dataset Summary:
37
- {column_data}
38
- """
39
-
40
- def generate_visual_plan(column_data):
41
- prompt = visual_prompt.format(column_data=json.dumps(column_data, indent=2))
42
- response = call_llm(prompt)
43
-
44
- match = re.search(r"\{.*\}", response, re.DOTALL)
45
- if match:
46
- try:
47
- parsed = json.loads(match.group(0))
48
- return parsed["visualizations"]
49
- except:
50
- print("Failed to parse visualization JSON.")
51
- print(response)
52
- return []
 
 
 
 
1
+ # visual_insight.py
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import json
5
+ import re
6
+ import os
7
+
8
+
9
+ model_id = "google/gemma-3n-E4B-it"
10
+ hf_token = os.environ.get("HUGGINGFACE_TOKEN")
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
12
+ model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
13
+
14
+ def call_llm(prompt):
15
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
16
+ outputs = model.generate(**inputs, max_new_tokens=2048)
17
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
+
19
+ visual_prompt = """
20
+ You are a data visualization expert. You will be given a summary of a cleaned dataset.
21
+
22
+ Your tasks:
23
+ 1. Suggest 3–5 interesting visualizations that would help uncover patterns or relationships.
24
+ 2. For each, describe what insight it may reveal.
25
+ 3. For each, write Python code using pandas/seaborn/matplotlib to generate the plot. Use 'df' as the dataframe and be precise with column names.
26
+ 4. Always be careful and precise with column names
27
+ Output JSON in this exact format:
28
+ {
29
+ "visualizations": [
30
+ {
31
+ "title": "Histogram of Age",
32
+ "description": "Shows the distribution of age",
33
+ "code": "sns.histplot(df['age'], kde=True); plt.title('Age Distribution'); plt.savefig('charts/age.png'); plt.clf()"
34
+ },
35
+ ...
36
+ ]
37
+ }
38
+
39
+ Dataset Summary:
40
+ {column_data}
41
+ """
42
+
43
+ def generate_visual_plan(column_data):
44
+ prompt = visual_prompt.format(column_data=json.dumps(column_data, indent=2))
45
+ response = call_llm(prompt)
46
+
47
+ match = re.search(r"\{.*\}", response, re.DOTALL)
48
+ if match:
49
+ try:
50
+ parsed = json.loads(match.group(0))
51
+ return parsed["visualizations"]
52
+ except:
53
+ print("Failed to parse visualization JSON.")
54
+ print(response)
55
+ return []