Victor Dibia commited on
Commit
61358ce
·
1 Parent(s): 8d3cc3c

recommender update

Browse files
MANIFEST.in CHANGED
@@ -1,2 +1,4 @@
1
  recursive-include lida/web/ui *
2
- recursive-exclude notebooks *
 
 
 
1
  recursive-include lida/web/ui *
2
+ recursive-exclude notebooks *
3
+ recursive-exclude docs *
4
+ recursive-exclude tests *
lida/modules/executor.py CHANGED
@@ -41,6 +41,8 @@ def preprocess_code(code: str) -> str:
41
  code = code[index:]
42
 
43
  code = code.replace("```", "")
 
 
44
  return code
45
 
46
 
 
41
  code = code[index:]
42
 
43
  code = code.replace("```", "")
44
+ if "chart = plot(data)" not in code:
45
+ code = code + "\nchart = plot(data)"
46
  return code
47
 
48
 
lida/modules/manager.py CHANGED
@@ -219,6 +219,7 @@ class Manager(object):
219
  self,
220
  code,
221
  summary: Summary,
 
222
  textgen_config: TextGenerationConfig = TextGenerationConfig(),
223
  library: str = "seaborn",
224
  ):
@@ -237,6 +238,7 @@ class Manager(object):
237
  return self.recommender.generate(
238
  code=code,
239
  summary=summary,
 
240
  textgen_config=textgen_config,
241
  text_gen=self.text_gen,
242
  library=library,
 
219
  self,
220
  code,
221
  summary: Summary,
222
+ n=4,
223
  textgen_config: TextGenerationConfig = TextGenerationConfig(),
224
  library: str = "seaborn",
225
  ):
 
238
  return self.recommender.generate(
239
  code=code,
240
  summary=summary,
241
+ n=n,
242
  textgen_config=textgen_config,
243
  text_gen=self.text_gen,
244
  library=library,
lida/modules/scaffold.py CHANGED
@@ -17,7 +17,7 @@ class ChartScaffold(object):
17
  pass
18
 
19
  def get_template(self, goal: Goal, library: str):
20
- mpl_pre = f"Set chart title to {goal.question}. If the solution requires a single value (e.g. max, min, median, first, last etc), ALWAYS add a line (axvline or axhline) to the chart, ALWAYS with a legend containing the single value (formatted with 0.2F). If using a <field> where semantic_type=date, YOU MUST APPLY the following transform before using that column i) convert date fields to date types using data[''] = pd.to_datetime(data[<field>], errors='coerce'), ALWAYS use errors='coerce' ii) drop the rows with NaT values data = data[pd.notna(data[<field>])] iii) convert field to right time format for plotting. ALWAYS make sure the x-axis labels are legible (e.g., rotate when needed). Use BaseMap for charts that require a map. Given the dataset summary, the plot(data) method should generate a {library} chart ({goal.visualization}) that addresses this goal: {goal.question}. The plot method must return a matplotlib object. Think step by step. \n"
21
 
22
  if library == "matplotlib":
23
  instructions = {"role": "assistant", "content": mpl_pre}
 
17
  pass
18
 
19
  def get_template(self, goal: Goal, library: str):
20
+ mpl_pre = f"Set chart title to {goal.question}. If the solution requires a single value (e.g. max, min, median, first, last etc), ALWAYS add a line (axvline or axhline) to the chart, ALWAYS with a legend containing the single value (formatted with 0.2F). If using a <field> where semantic_type=date, YOU MUST APPLY the following transform before using that column i) convert date fields to date types using data[''] = pd.to_datetime(data[<field>], errors='coerce'), ALWAYS use errors='coerce' ii) drop the rows with NaT values data = data[pd.notna(data[<field>])] iii) convert field to right time format for plotting. ALWAYS make sure the x-axis labels are legible (e.g., rotate when needed). Use BaseMap for charts that require a map. Given the dataset summary, the plot(data) method should generate a {library} chart ({goal.visualization}) that addresses this goal: {goal.question}. Do not include plt.show(). The plot method must return a matplotlib object. Think step by step. \n"
21
 
22
  if library == "matplotlib":
23
  instructions = {"role": "assistant", "content": mpl_pre}
lida/modules/viz/vizrecommender.py CHANGED
@@ -1,4 +1,6 @@
1
-
 
 
2
  from lida.modules.scaffold import ChartScaffold
3
  from llmx import TextGenerator, TextGenerationConfig, TextGenerationResponse
4
  # from lida.modules.scaffold import ChartScaffold
@@ -6,9 +8,16 @@ from lida.datamodel import Goal, Summary
6
 
7
 
8
  system_prompt = """
9
- You are a helpful assistant highly skilled in recommending a DIVERSE set of visualizations. Your input is an existing visualization, and a summary of a dataset and an example visualization goal. Given this input, your task is to recommend an additional DIVERSE visualization that a user may be interesting to a user. Consider different types of valid aggregations, chart types, and use different variables from the data summary. THE CODE YOU GENERATE MUST BE CORRECT AND FOLLOW VISUALIZATION BEST PRACTICES. You MUST return a full program. DO NOT include any preamble text. Do not include explanations or prose.
 
 
 
 
10
  """
11
 
 
 
 
12
 
13
  class VizRecommender(object):
14
  """Generate visualizations from prompt"""
@@ -22,6 +31,7 @@ class VizRecommender(object):
22
  self, code: str, summary: Summary,
23
  textgen_config: TextGenerationConfig,
24
  text_gen: TextGenerator,
 
25
  library='altair'):
26
  """Recommend a code spec based on existing visualization"""
27
 
@@ -36,12 +46,24 @@ class VizRecommender(object):
36
  {"role": "system", "content": f"The dataset summary is : {summary}"},
37
  {"role": "system",
38
  "content":
39
- f"The original visualization code is: {code}. You MUST use only the {library} library with the following instructions {library_instructions}. The resulting code MUST use the following template {library_template}."},
40
- {"role": "user", "content": "Now write code for an additional visualizations that a user may be interested in given the goal and the dataset summary above."}
41
-
42
  ]
43
 
44
  textgen_config.messages = messages
45
- completions: TextGenerationResponse = text_gen.generate(
46
  messages=messages, config=textgen_config)
47
- return [x['content'] for x in completions.text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ from lida.utils import clean_code_snippet
4
  from lida.modules.scaffold import ChartScaffold
5
  from llmx import TextGenerator, TextGenerationConfig, TextGenerationResponse
6
  # from lida.modules.scaffold import ChartScaffold
 
8
 
9
 
10
  system_prompt = """
11
+ You are a helpful assistant highly skilled in recommending a DIVERSE set of visualizations as code. Your input is an example visualization code, a summary of a dataset and an example visualization goal. Given this input, your task is to recommend an additional DIVERSE visualizations that a user may be interesting to a user. Consider different types of valid aggregations, chart types, and use different variables from the data summary. THE CODE YOU GENERATE MUST BE CORRECT AND FOLLOW VISUALIZATION BEST PRACTICES.
12
+
13
+ Your output MUST be perfect JSON in THE FORM OF A VALID JSON LIST without any additional explanation e.g.,
14
+
15
+ [{"code": "import ...", "index":0}, .. {"code": "import ...", "index":1} ]
16
  """
17
 
18
+ # refactor this to return n predictions ...
19
+ logger = logging.getLogger(__name__)
20
+
21
 
22
  class VizRecommender(object):
23
  """Generate visualizations from prompt"""
 
31
  self, code: str, summary: Summary,
32
  textgen_config: TextGenerationConfig,
33
  text_gen: TextGenerator,
34
+ n=3,
35
  library='altair'):
36
  """Recommend a code spec based on existing visualization"""
37
 
 
46
  {"role": "system", "content": f"The dataset summary is : {summary}"},
47
  {"role": "system",
48
  "content":
49
+ f"An example visualization code is: {code}. You MUST use only the {library} library with the following instructions {library_instructions}. Each recommended visualization CODE MUST use the following template {library_template}."},
50
+ {"role": "user", "content": f"Now write code for {n} visualizations in the JSON list format. YOU MUST RETURN ONLY A JSON LIST"}
 
51
  ]
52
 
53
  textgen_config.messages = messages
54
+ result: TextGenerationResponse = text_gen.generate(
55
  messages=messages, config=textgen_config)
56
+ try:
57
+ json_string = clean_code_snippet(result.text[0]["content"])
58
+ result = json.loads(json_string)
59
+ if isinstance(result, dict):
60
+ result = [result]
61
+ result = [x["code"] for x in result]
62
+ except json.decoder.JSONDecodeError:
63
+ logger.info(
64
+ f"Error decoding JSON for generated visualization recommendations: {result.text[0]['content']}")
65
+ print(
66
+ f"Error decoding JSON for generated visualization recommendations: {result.text[0]['content']}")
67
+ raise ValueError(
68
+ "The model did not return a valid JSON object while attempting generate visualization recommendations. Please try again.")
69
+ return result
notebooks/tutorial.ipynb CHANGED
The diff for this file is too large to render. See raw diff