Spaces:
Sleeping
Sleeping
dolphinium
committed on
Commit
·
8290c25
1
Parent(s):
840c57d
enhance viz code generation prompt
Browse files
app.py
CHANGED
|
@@ -328,33 +328,174 @@ def llm_generate_visualization_code(query_context, facet_data):
|
|
| 328 |
"""Generates Python code for visualization based on query and data."""
|
| 329 |
prompt = f"""
|
| 330 |
You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
|
| 331 |
-
Your task is to generate Python code to create a single, insightful visualization.
|
| 332 |
|
| 333 |
-
**
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
"""
|
| 352 |
try:
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
| 354 |
code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
|
| 355 |
return code
|
| 356 |
except Exception as e:
|
| 357 |
-
print(f"Error in llm_generate_visualization_code: {e}")
|
| 358 |
return None
|
| 359 |
|
| 360 |
def execute_viz_code_and_get_path(viz_code, facet_data):
|
|
|
|
| 328 |
"""Generates Python code for visualization based on query and data."""
|
| 329 |
prompt = f"""
|
| 330 |
You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
|
| 331 |
+
Your task is to generate robust, error-free Python code to create a single, insightful visualization based on the user's query and the provided Solr facet data.
|
| 332 |
|
| 333 |
+
**User's Analytical Goal:**
|
| 334 |
+
"{query_context}"
|
| 335 |
+
|
| 336 |
+
**Aggregated Data (from Solr Facets):**
|
| 337 |
+
```json
|
| 338 |
+
{json.dumps(facet_data, indent=2)}
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
---
|
| 342 |
+
### **CRITICAL INSTRUCTIONS: CODE GENERATION RULES**
|
| 343 |
+
You MUST follow these rules to avoid errors.
|
| 344 |
+
|
| 345 |
+
**1. Identify the Data Structure FIRST:**
|
| 346 |
+
Before writing any code, analyze the `facet_data` JSON to determine its structure. There are three common patterns. Choose the correct template below.
|
| 347 |
+
|
| 348 |
+
* **Pattern A: Simple `terms` Facet.** The JSON has ONE main key (besides "count") which contains a list of "buckets". Each bucket has a "val" and a "count". Use this for standard bar charts.
|
| 349 |
+
* **Pattern B: Multiple `query` Facets.** The JSON has MULTIPLE keys (besides "count"), and each key is an object containing metrics like "count" or "sum(...)". Use this for comparing a few distinct items (e.g., "oral vs injection").
|
| 350 |
+
* **Pattern C: Nested `terms` Facet.** The JSON has one main key with a list of "buckets", but inside EACH bucket, there are nested metric objects. This is used for grouped comparisons (e.g., "compare 2024 vs 2025 across categories"). This almost always requires `pandas`.
|
| 351 |
+
|
| 352 |
+
**2. Use the Correct Parsing Template:**
|
| 353 |
+
|
| 354 |
+
---
|
| 355 |
+
**TEMPLATE FOR PATTERN A (Simple Bar Chart from `terms` facet):**
|
| 356 |
+
```python
|
| 357 |
+
import matplotlib.pyplot as plt
|
| 358 |
+
import seaborn as sns
|
| 359 |
+
import pandas as pd
|
| 360 |
+
|
| 361 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
| 362 |
+
fig, ax = plt.subplots(figsize=(12, 8))
|
| 363 |
+
|
| 364 |
+
# Dynamically find the main facet key (the one with 'buckets')
|
| 365 |
+
facet_key = None
|
| 366 |
+
for key, value in facet_data.items():
|
| 367 |
+
if isinstance(value, dict) and 'buckets' in value:
|
| 368 |
+
facet_key = key
|
| 369 |
+
break
|
| 370 |
+
|
| 371 |
+
if facet_key:
|
| 372 |
+
buckets = facet_data[facet_key].get('buckets', [])
|
| 373 |
+
# Check if buckets contain data
|
| 374 |
+
if buckets:
|
| 375 |
+
df = pd.DataFrame(buckets)
|
| 376 |
+
# Check for a nested metric or use 'count'
|
| 377 |
+
if 'total_deal_value' in df.columns and pd.api.types.is_dict_like(df['total_deal_value'].iloc[0]):
|
| 378 |
+
# Example for nested sum metric
|
| 379 |
+
df['value'] = df['total_deal_value'].apply(lambda x: x.get('sum', 0))
|
| 380 |
+
y_axis_label = 'Sum of Total Deal Value'
|
| 381 |
+
else:
|
| 382 |
+
df.rename(columns={{'count': 'value'}}, inplace=True)
|
| 383 |
+
y_axis_label = 'Count'
|
| 384 |
+
|
| 385 |
+
sns.barplot(data=df, x='val', y='value', ax=ax, palette='viridis')
|
| 386 |
+
ax.set_xlabel('Category')
|
| 387 |
+
ax.set_ylabel(y_axis_label)
|
| 388 |
+
else:
|
| 389 |
+
ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
ax.set_title('Your Insightful Title Here')
|
| 393 |
+
# Correct way to rotate labels to prevent errors
|
| 394 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
| 395 |
+
plt.tight_layout()
|
| 396 |
+
```
|
| 397 |
+
---
|
| 398 |
+
**TEMPLATE FOR PATTERN B (Comparison Bar Chart from `query` facets):**
|
| 399 |
+
```python
|
| 400 |
+
import matplotlib.pyplot as plt
|
| 401 |
+
import seaborn as sns
|
| 402 |
+
import pandas as pd
|
| 403 |
+
|
| 404 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
| 405 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 406 |
+
|
| 407 |
+
labels = []
|
| 408 |
+
values = []
|
| 409 |
+
# Iterate through top-level keys, skipping the 'count'
|
| 410 |
+
for key, data_dict in facet_data.items():
|
| 411 |
+
if key == 'count' or not isinstance(data_dict, dict):
|
| 412 |
+
continue
|
| 413 |
+
# Extract the label (e.g., 'oral_deals' -> 'Oral')
|
| 414 |
+
label = key.replace('_deals', '').replace('_', ' ').title()
|
| 415 |
+
# Find the metric value, which is NOT 'count'
|
| 416 |
+
metric_value = 0
|
| 417 |
+
for sub_key, sub_value in data_dict.items():
|
| 418 |
+
if sub_key != 'count':
|
| 419 |
+
metric_value = sub_value
|
| 420 |
+
break # Found the metric
|
| 421 |
+
labels.append(label)
|
| 422 |
+
values.append(metric_value)
|
| 423 |
+
|
| 424 |
+
if labels:
|
| 425 |
+
sns.barplot(x=labels, y=values, ax=ax, palette='mako')
|
| 426 |
+
ax.set_ylabel('Total Deal Value') # Or other metric name
|
| 427 |
+
ax.set_xlabel('Category')
|
| 428 |
+
else:
|
| 429 |
+
ax.text(0.5, 0.5, 'No query facet data to plot.', ha='center')
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
ax.set_title('Your Insightful Title Here')
|
| 433 |
+
plt.tight_layout()
|
| 434 |
+
```
|
| 435 |
+
---
|
| 436 |
+
**TEMPLATE FOR PATTERN C (Grouped Bar Chart from nested `terms` facet):**
|
| 437 |
+
```python
|
| 438 |
+
import matplotlib.pyplot as plt
|
| 439 |
+
import seaborn as sns
|
| 440 |
+
import pandas as pd
|
| 441 |
|
| 442 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
| 443 |
+
fig, ax = plt.subplots(figsize=(14, 8))
|
| 444 |
+
|
| 445 |
+
# Find the key that has the buckets
|
| 446 |
+
facet_key = None
|
| 447 |
+
for key, value in facet_data.items():
|
| 448 |
+
if isinstance(value, dict) and 'buckets' in value:
|
| 449 |
+
facet_key = key
|
| 450 |
+
break
|
| 451 |
+
|
| 452 |
+
if facet_key and facet_data[facet_key].get('buckets'):
|
| 453 |
+
# This nested loop is robust for parsing nested metrics
|
| 454 |
+
plot_data = []
|
| 455 |
+
for bucket in facet_data[facet_key]['buckets']:
|
| 456 |
+
category = bucket['val']
|
| 457 |
+
# Find all nested metrics (e.g., total_deal_value_2025)
|
| 458 |
+
for sub_key, sub_value in bucket.items():
|
| 459 |
+
if isinstance(sub_value, dict) and 'sum' in sub_value:
|
| 460 |
+
# Extracts year from 'total_deal_value_2025' -> '2025'
|
| 461 |
+
year = sub_key.split('_')[-1]
|
| 462 |
+
value = sub_value['sum']
|
| 463 |
+
plot_data.append({{'Category': category, 'Year': year, 'Value': value}})
|
| 464 |
+
|
| 465 |
+
if plot_data:
|
| 466 |
+
df = pd.DataFrame(plot_data)
|
| 467 |
+
sns.barplot(data=df, x='Category', y='Value', hue='Year', ax=ax)
|
| 468 |
+
ax.set_ylabel('Total Deal Value')
|
| 469 |
+
ax.set_xlabel('Business Model')
|
| 470 |
+
# Correct way to rotate labels to prevent errors
|
| 471 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
| 472 |
+
else:
|
| 473 |
+
ax.text(0.5, 0.5, 'No nested data found to plot.', ha='center')
|
| 474 |
+
else:
|
| 475 |
+
ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
|
| 476 |
+
|
| 477 |
+
ax.set_title('Your Insightful Title Here')
|
| 478 |
+
plt.tight_layout()
|
| 479 |
+
```
|
| 480 |
+
---
|
| 481 |
+
**3. Final Code Generation:**
|
| 482 |
+
- **DO NOT** include `plt.show()`.
|
| 483 |
+
- **DO** set a dynamic and descriptive `ax.set_title()`, `ax.set_xlabel()`, and `ax.set_ylabel()`.
|
| 484 |
+
- **DO NOT** wrap the code in ```python ... ```. Output only the raw Python code.
|
| 485 |
+
- Adapt the chosen template to the specific keys and metrics in the provided `facet_data`.
|
| 486 |
+
|
| 487 |
+
**Your Task:**
|
| 488 |
+
Now, generate the Python code.
|
| 489 |
"""
|
| 490 |
try:
|
| 491 |
+
# Use deterministic sampling (temperature=0) and cap the output at 2048 tokens
|
| 492 |
+
generation_config = genai.types.GenerationConfig(temperature=0, max_output_tokens=2048)
|
| 493 |
+
response = llm_model.generate_content(prompt, generation_config=generation_config)
|
| 494 |
+
# Clean the response to remove markdown formatting
|
| 495 |
code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
|
| 496 |
return code
|
| 497 |
except Exception as e:
|
| 498 |
+
print(f"Error in llm_generate_visualization_code: {e}\nRaw response: {response.text}")
|
| 499 |
return None
|
| 500 |
|
| 501 |
def execute_viz_code_and_get_path(viz_code, facet_data):
|