File size: 1,930 Bytes
583755a
7509c8d
 
583755a
 
 
 
 
9ac0790
 
79a2070
f38d4d9
 
0f569ec
7509c8d
 
0f569ec
583755a
 
 
 
 
 
 
7a6576c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583755a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoProcessor, AutoModelForImageTextToText

import torch
import json
import re
import os


# Hugging Face Hub settings. The access token is read from the environment
# (None when HUGGINGFACE_TOKEN is unset); downloads are cached in cache_dir.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
# NOTE(review): this id has no "-it" suffix, so it looks like the base
# checkpoint rather than the instruction-tuned one, while the prompts in this
# file are written as instructions — confirm the intended variant.
model_id = "google/gemma-3n-E4B"
cache_dir = "/tmp/hf_cache"

# Keyword arguments shared by both from_pretrained calls below.
_hub_options = {"token": hf_token, "cache_dir": cache_dir}

# Despite its name, `tokenizer` is the model's AutoProcessor, not a bare
# tokenizer; it is loaded once at import time together with the model.
tokenizer = AutoProcessor.from_pretrained(model_id, **_hub_options)
model = AutoModelForImageTextToText.from_pretrained(model_id, **_hub_options)


def call_llm(prompt, max_new_tokens=2048):
    """Generate a completion for *prompt* and return only the newly generated text.

    Args:
        prompt: Plain-text prompt to send to the module-level ``model``.
        max_new_tokens: Maximum number of tokens to generate (defaults to the
            previously hard-coded 2048).

    Returns:
        The decoded continuation with special tokens stripped. The prompt
        itself is not included in the returned string.
    """
    # Pass the prompt as ``text=``: multimodal processors returned by
    # AutoProcessor do not take text as their first positional argument, and
    # the keyword form also works for a plain tokenizer.
    inputs = tokenizer(text=prompt, return_tensors="pt", truncation=True)
    # Keep input tensors on the same device as the model weights.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt tokens followed by the continuation; drop the
    # prompt so callers do not parse an echo of their own input.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Prompt template for visualization suggestions. `{column_data}` is the only
# placeholder; every other brace is literal JSON from the output example, so
# str.format() cannot be used on this template (it would parse those braces
# as replacement fields) — fill the placeholder with plain string replacement.
visual_prompt = """
You are a data visualization expert. You will be given a summary of a cleaned dataset.
Your tasks:
1. Suggest 3–5 interesting visualizations that would help uncover patterns or relationships.
2. For each, describe what insight it may reveal.
3. For each, write Python code using pandas/seaborn/matplotlib to generate the plot. Use 'df' as the dataframe.
4. Use column names exactly as they appear in the dataset summary; do not invent or rename columns.
Respond with only the JSON object, with no explanation before or after it.
Output JSON in this exact format:
{
  "visualizations": [
    {
      "title": "Histogram of Age",
      "description": "Shows the distribution of age",
      "code": "sns.histplot(df['age'], kde=True); plt.title('Age Distribution'); plt.savefig('charts/age.png'); plt.clf()"
    },
    ...
  ]
}
Dataset Summary:
{column_data}
"""

def generate_visual_plan(column_data, *, llm=None, prompt_template=None):
    """Ask the LLM for visualization suggestions for a dataset summary.

    Args:
        column_data: Summary of the dataset, embedded in the prompt via
            ``json.dumps``. Values ``json`` cannot encode natively (e.g. numpy
            scalars) are converted with ``str``.
        llm: Callable that takes a prompt string and returns the model's text.
            Defaults to this module's ``call_llm``.
        prompt_template: Prompt text containing a ``{column_data}``
            placeholder. Defaults to this module's ``visual_prompt``.

    Returns:
        The list stored under ``"visualizations"`` in the first JSON object
        of the response that has such a list, or ``[]`` if none is found (the
        raw response is printed in that case).
    """
    if llm is None:
        llm = call_llm
    if prompt_template is None:
        prompt_template = visual_prompt

    # default=str: summaries built from pandas often contain numpy scalars,
    # which json.dumps would otherwise reject with TypeError.
    summary = json.dumps(column_data, indent=2, default=str)
    # Plain replacement instead of str.format: the template contains literal
    # JSON braces, which str.format would try to parse as replacement fields.
    prompt = prompt_template.replace("{column_data}", summary)
    response = llm(prompt)

    visualizations = _extract_visualizations(response)
    if visualizations is None:
        print("Failed to parse visualization JSON.")
        print(response)
        return []
    return visualizations


def _extract_visualizations(text):
    """Return the "visualizations" list of the first JSON object in *text* that has one, else None."""
    decoder = json.JSONDecoder()
    # Try each '{' as a candidate object start. raw_decode stops at the end of
    # one complete object, so surrounding prose, markdown fences, stray braces,
    # or an echoed prompt do not break parsing the way a greedy regex does.
    for match in re.finditer(r"\{", text):
        try:
            candidate, _ = decoder.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and isinstance(candidate.get("visualizations"), list):
            return candidate["visualizations"]
    return None