rairo commited on
Commit
eeda2ec
·
verified ·
1 Parent(s): 38ecc79

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +4 -3
sozo_gen.py CHANGED
@@ -106,9 +106,10 @@ class ChartGenerator:
106
  self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
107
 
108
  def generate_chart_spec(self, description: str) -> ChartSpecification:
 
109
  spec_prompt = f"""
110
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
111
- **Dataset Info:** {json.dumps(self.enhanced_ctx, indent=2)}
112
  **Chart Request:** {description}
113
  **Return a JSON specification with these exact fields:**
114
  {{
@@ -224,7 +225,7 @@ def animate_image_fade(img: np.ndarray, dur: float, out: Path, fps: int = 24) ->
224
 
225
  def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
226
  try:
227
- llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
228
  chart_generator = ChartGenerator(llm, df)
229
  chart_spec = chart_generator.generate_chart_spec(desc)
230
  return animate_chart(chart_spec, df, dur, out)
@@ -232,7 +233,7 @@ def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
232
  logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
233
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
234
  temp_png = Path(temp_png_file.name)
235
- llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=API_KEY, temperature=0.1)
236
  chart_generator = ChartGenerator(llm, df)
237
  chart_spec = chart_generator.generate_chart_spec(desc)
238
  if execute_chart_spec(chart_spec, df, temp_png):
 
106
  self.enhanced_ctx = enhance_data_context(df, {"columns": list(df.columns), "shape": df.shape, "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()}})
107
 
108
  def generate_chart_spec(self, description: str) -> ChartSpecification:
109
+ safe_ctx = json_serializable(self.enhanced_ctx)
110
  spec_prompt = f"""
111
  You are a data visualization expert. Based on the dataset and chart description, generate a precise chart specification.
112
+ **Dataset Info:** {json.dumps(safe_ctx, indent=2)}
113
  **Chart Request:** {description}
114
  **Return a JSON specification with these exact fields:**
115
  {{
 
225
 
226
  def safe_chart(desc: str, df: pd.DataFrame, dur: float, out: Path) -> str:
227
  try:
228
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
229
  chart_generator = ChartGenerator(llm, df)
230
  chart_spec = chart_generator.generate_chart_spec(desc)
231
  return animate_chart(chart_spec, df, dur, out)
 
233
  logging.error(f"Chart animation failed for '{desc}': {e}. Falling back to static image.")
234
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_png_file:
235
  temp_png = Path(temp_png_file.name)
236
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY, temperature=0.1)
237
  chart_generator = ChartGenerator(llm, df)
238
  chart_spec = chart_generator.generate_chart_spec(desc)
239
  if execute_chart_spec(chart_spec, df, temp_png):