Victor Dibia committed on
Commit ·
61358ce
1
Parent(s): 8d3cc3c
recommender update
Browse files
- MANIFEST.in +3 -1
- lida/modules/executor.py +2 -0
- lida/modules/manager.py +2 -0
- lida/modules/scaffold.py +1 -1
- lida/modules/viz/vizrecommender.py +29 -7
- notebooks/tutorial.ipynb +0 -0
MANIFEST.in
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
recursive-include lida/web/ui *
|
| 2 |
-
recursive-exclude notebooks *
|
|
|
|
|
|
|
|
|
| 1 |
recursive-include lida/web/ui *
|
| 2 |
+
recursive-exclude notebooks *
|
| 3 |
+
recursive-exclude docs *
|
| 4 |
+
recursive-exclude tests *
|
lida/modules/executor.py
CHANGED
|
@@ -41,6 +41,8 @@ def preprocess_code(code: str) -> str:
|
|
| 41 |
code = code[index:]
|
| 42 |
|
| 43 |
code = code.replace("```", "")
|
|
|
|
|
|
|
| 44 |
return code
|
| 45 |
|
| 46 |
|
|
|
|
| 41 |
code = code[index:]
|
| 42 |
|
| 43 |
code = code.replace("```", "")
|
| 44 |
+
if "chart = plot(data)" not in code:
|
| 45 |
+
code = code + "\nchart = plot(data)"
|
| 46 |
return code
|
| 47 |
|
| 48 |
|
lida/modules/manager.py
CHANGED
|
@@ -219,6 +219,7 @@ class Manager(object):
|
|
| 219 |
self,
|
| 220 |
code,
|
| 221 |
summary: Summary,
|
|
|
|
| 222 |
textgen_config: TextGenerationConfig = TextGenerationConfig(),
|
| 223 |
library: str = "seaborn",
|
| 224 |
):
|
|
@@ -237,6 +238,7 @@ class Manager(object):
|
|
| 237 |
return self.recommender.generate(
|
| 238 |
code=code,
|
| 239 |
summary=summary,
|
|
|
|
| 240 |
textgen_config=textgen_config,
|
| 241 |
text_gen=self.text_gen,
|
| 242 |
library=library,
|
|
|
|
| 219 |
self,
|
| 220 |
code,
|
| 221 |
summary: Summary,
|
| 222 |
+
n=4,
|
| 223 |
textgen_config: TextGenerationConfig = TextGenerationConfig(),
|
| 224 |
library: str = "seaborn",
|
| 225 |
):
|
|
|
|
| 238 |
return self.recommender.generate(
|
| 239 |
code=code,
|
| 240 |
summary=summary,
|
| 241 |
+
n=n,
|
| 242 |
textgen_config=textgen_config,
|
| 243 |
text_gen=self.text_gen,
|
| 244 |
library=library,
|
lida/modules/scaffold.py
CHANGED
|
@@ -17,7 +17,7 @@ class ChartScaffold(object):
|
|
| 17 |
pass
|
| 18 |
|
| 19 |
def get_template(self, goal: Goal, library: str):
|
| 20 |
-
mpl_pre = f"Set chart title to {goal.question}. If the solution requires a single value (e.g. max, min, median, first, last etc), ALWAYS add a line (axvline or axhline) to the chart, ALWAYS with a legend containing the single value (formatted with 0.2F). If using a <field> where semantic_type=date, YOU MUST APPLY the following transform before using that column i) convert date fields to date types using data[''] = pd.to_datetime(data[<field>], errors='coerce'), ALWAYS use errors='coerce' ii) drop the rows with NaT values data = data[pd.notna(data[<field>])] iii) convert field to right time format for plotting. ALWAYS make sure the x-axis labels are legible (e.g., rotate when needed). Use BaseMap for charts that require a map. Given the dataset summary, the plot(data) method should generate a {library} chart ({goal.visualization}) that addresses this goal: {goal.question}. The plot method must return a matplotlib object. Think step by step. \n"
|
| 21 |
|
| 22 |
if library == "matplotlib":
|
| 23 |
instructions = {"role": "assistant", "content": mpl_pre}
|
|
|
|
| 17 |
pass
|
| 18 |
|
| 19 |
def get_template(self, goal: Goal, library: str):
|
| 20 |
+
mpl_pre = f"Set chart title to {goal.question}. If the solution requires a single value (e.g. max, min, median, first, last etc), ALWAYS add a line (axvline or axhline) to the chart, ALWAYS with a legend containing the single value (formatted with 0.2F). If using a <field> where semantic_type=date, YOU MUST APPLY the following transform before using that column i) convert date fields to date types using data[''] = pd.to_datetime(data[<field>], errors='coerce'), ALWAYS use errors='coerce' ii) drop the rows with NaT values data = data[pd.notna(data[<field>])] iii) convert field to right time format for plotting. ALWAYS make sure the x-axis labels are legible (e.g., rotate when needed). Use BaseMap for charts that require a map. Given the dataset summary, the plot(data) method should generate a {library} chart ({goal.visualization}) that addresses this goal: {goal.question}. Do not include plt.show(). The plot method must return a matplotlib object. Think step by step. \n"
|
| 21 |
|
| 22 |
if library == "matplotlib":
|
| 23 |
instructions = {"role": "assistant", "content": mpl_pre}
|
lida/modules/viz/vizrecommender.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
from lida.modules.scaffold import ChartScaffold
|
| 3 |
from llmx import TextGenerator, TextGenerationConfig, TextGenerationResponse
|
| 4 |
# from lida.modules.scaffold import ChartScaffold
|
|
@@ -6,9 +8,16 @@ from lida.datamodel import Goal, Summary
|
|
| 6 |
|
| 7 |
|
| 8 |
system_prompt = """
|
| 9 |
-
You are a helpful assistant highly skilled in recommending a DIVERSE set of visualizations. Your input is an
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class VizRecommender(object):
|
| 14 |
"""Generate visualizations from prompt"""
|
|
@@ -22,6 +31,7 @@ class VizRecommender(object):
|
|
| 22 |
self, code: str, summary: Summary,
|
| 23 |
textgen_config: TextGenerationConfig,
|
| 24 |
text_gen: TextGenerator,
|
|
|
|
| 25 |
library='altair'):
|
| 26 |
"""Recommend a code spec based on existing visualization"""
|
| 27 |
|
|
@@ -36,12 +46,24 @@ class VizRecommender(object):
|
|
| 36 |
{"role": "system", "content": f"The dataset summary is : {summary}"},
|
| 37 |
{"role": "system",
|
| 38 |
"content":
|
| 39 |
-
f"
|
| 40 |
-
{"role": "user", "content": "Now write code for
|
| 41 |
-
|
| 42 |
]
|
| 43 |
|
| 44 |
textgen_config.messages = messages
|
| 45 |
-
|
| 46 |
messages=messages, config=textgen_config)
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import json
|
| 3 |
+
from lida.utils import clean_code_snippet
|
| 4 |
from lida.modules.scaffold import ChartScaffold
|
| 5 |
from llmx import TextGenerator, TextGenerationConfig, TextGenerationResponse
|
| 6 |
# from lida.modules.scaffold import ChartScaffold
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
system_prompt = """
|
| 11 |
+
You are a helpful assistant highly skilled in recommending a DIVERSE set of visualizations as code. Your input is an example visualization code, a summary of a dataset and an example visualization goal. Given this input, your task is to recommend an additional DIVERSE visualizations that a user may be interesting to a user. Consider different types of valid aggregations, chart types, and use different variables from the data summary. THE CODE YOU GENERATE MUST BE CORRECT AND FOLLOW VISUALIZATION BEST PRACTICES.
|
| 12 |
+
|
| 13 |
+
Your output MUST be perfect JSON in THE FORM OF A VALID JSON LIST without any additional explanation e.g.,
|
| 14 |
+
|
| 15 |
+
[{"code": "import ...", "index":0}, .. {"code": "import ...", "index":1} ]
|
| 16 |
"""
|
| 17 |
|
| 18 |
+
# refactor this to return n predictions ...
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
|
| 22 |
class VizRecommender(object):
|
| 23 |
"""Generate visualizations from prompt"""
|
|
|
|
| 31 |
self, code: str, summary: Summary,
|
| 32 |
textgen_config: TextGenerationConfig,
|
| 33 |
text_gen: TextGenerator,
|
| 34 |
+
n=3,
|
| 35 |
library='altair'):
|
| 36 |
"""Recommend a code spec based on existing visualization"""
|
| 37 |
|
|
|
|
| 46 |
{"role": "system", "content": f"The dataset summary is : {summary}"},
|
| 47 |
{"role": "system",
|
| 48 |
"content":
|
| 49 |
+
f"An example visualization code is: {code}. You MUST use only the {library} library with the following instructions {library_instructions}. Each recommended visualization CODE MUST use the following template {library_template}."},
|
| 50 |
+
{"role": "user", "content": f"Now write code for {n} visualizations in the JSON list format. YOU MUST RETURN ONLY A JSON LIST"}
|
|
|
|
| 51 |
]
|
| 52 |
|
| 53 |
textgen_config.messages = messages
|
| 54 |
+
result: TextGenerationResponse = text_gen.generate(
|
| 55 |
messages=messages, config=textgen_config)
|
| 56 |
+
try:
|
| 57 |
+
json_string = clean_code_snippet(result.text[0]["content"])
|
| 58 |
+
result = json.loads(json_string)
|
| 59 |
+
if isinstance(result, dict):
|
| 60 |
+
result = [result]
|
| 61 |
+
result = [x["code"] for x in result]
|
| 62 |
+
except json.decoder.JSONDecodeError:
|
| 63 |
+
logger.info(
|
| 64 |
+
f"Error decoding JSON for generated visualization recommendations: {result.text[0]['content']}")
|
| 65 |
+
print(
|
| 66 |
+
f"Error decoding JSON for generated visualization recommendations: {result.text[0]['content']}")
|
| 67 |
+
raise ValueError(
|
| 68 |
+
"The model did not return a valid JSON object while attempting generate visualization recommendations. Please try again.")
|
| 69 |
+
return result
|
notebooks/tutorial.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|