Osnly commited on
Commit
f38d4d9
·
verified ·
1 Parent(s): 1a8550a

Update src/visual_insight.py

Browse files
Files changed (1) hide show
  1. src/visual_insight.py +11 -25
src/visual_insight.py CHANGED
@@ -1,14 +1,19 @@
1
- # visual_insight.py
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import json
5
  import re
6
  import os
7
 
8
-
9
  model_id = "google/gemma-3n-E4B-it"
10
- hf_token = os.environ.get("HUGGINGFACE_TOKEN")
11
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
 
 
 
 
 
 
 
12
  model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
13
 
14
  def call_llm(prompt):
@@ -17,27 +22,8 @@ def call_llm(prompt):
17
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
19
  visual_prompt = """
20
- You are a data visualization expert. You will be given a summary of a cleaned dataset.
21
-
22
- Your tasks:
23
- 1. Suggest 3–5 interesting visualizations that would help uncover patterns or relationships.
24
- 2. For each, describe what insight it may reveal.
25
- 3. For each, write Python code using pandas/seaborn/matplotlib to generate the plot. Use 'df' as the dataframe and be precise with column names.
26
- 4. Always be careful and precise with column names
27
- Output JSON in this exact format:
28
- {
29
- "visualizations": [
30
- {
31
- "title": "Histogram of Age",
32
- "description": "Shows the distribution of age",
33
- "code": "sns.histplot(df['age'], kde=True); plt.title('Age Distribution'); plt.savefig('charts/age.png'); plt.clf()"
34
- },
35
- ...
36
- ]
37
- }
38
-
39
- Dataset Summary:
40
- {column_data}
41
  """
42
 
43
  def generate_visual_plan(column_data):
 
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import torch
3
  import json
4
  import re
5
  import os
6
 
 
7
  model_id = "google/gemma-3n-E4B-it"
8
+
9
+ # Set Hugging Face cache directory
10
+ HF_CACHE_DIR = "./hf_cache"
11
+ os.environ["HF_HOME"] = HF_CACHE_DIR
12
+ os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
13
+ os.makedirs(HF_CACHE_DIR, exist_ok=True)
14
+
15
+ hf_token = os.environ.get("HUGGINGFACE_TOKEN")
16
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token, use_auth_token=True)
17
  model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
18
 
19
  def call_llm(prompt):
 
22
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
23
 
24
  visual_prompt = """
25
+ You are a data visualization expert...
26
+ [TRUNCATED for brevity, use your full original template]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
 
29
  def generate_visual_plan(column_data):