# wine_analysis / app.py — commit fc5823f ("polish") by opti-alexc
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import gradio as gr
import ast
from functools import lru_cache
from collections import Counter
import requests
import os
import time
# --- Constants and Mappings ---
# Category orders used for stacked-bar columns and distributions.
BODY_ORDER = ['Very light-bodied', 'Light-bodied', 'Medium-bodied', 'Full-bodied', 'Very full-bodied']
ACIDITY_ORDER = ['Low', 'Medium', 'High']
# Ordinal 1-5 score per body category, derived from BODY_ORDER so the two cannot drift apart.
BODY_MAPPING = {body: rank for rank, body in enumerate(BODY_ORDER, start=1)}
# Display-ordering key per wine type; types not listed here map to NaN.
WINE_TYPE_ORDER = {'Red': 2, 'Rosé': 1, 'White': 0}
# UI label -> minimum number of wines a group needs to be shown.
SAMPLE_THRESHOLDS = {
    'Very Common (250+)': 250,
    'Common (100+)': 100,
    'Uncommon (50+)': 50,
    'Rare (20+)': 20
}
# Country name (as spelled in the CSV 'Country' column) -> flag emoji.
COUNTRY_FLAGS = {
    'United States': '🇺🇸', 'France': '🇫🇷', 'Italy': '🇮🇹', 'Spain': '🇪🇸', 'Germany': '🇩🇪', 'Australia': '🇦🇺',
    'Chile': '🇨🇱', 'Argentina': '🇦🇷', 'Portugal': '🇵🇹', 'South Africa': '🇿🇦', 'New Zealand': '🇳🇿',
    'Austria': '🇦🇹', 'Greece': '🇬🇷', 'Hungary': '🇭🇺', 'Croatia': '🇭🇷', 'Slovenia': '🇸🇮', 'Canada': '🇨🇦',
    'Brazil': '🇧🇷', 'Uruguay': '🇺🇾', 'Israel': '🇮🇱', 'Lebanon': '🇱🇧', 'Turkey': '🇹🇷', 'Bulgaria': '🇧🇬',
    'Romania': '🇷🇴', 'Georgia': '🇬🇪', 'Moldova': '🇲🇩', 'Switzerland': '🇨🇭', 'England': '🏴'
}
# Food-pairing name (as found in the parsed 'Harmonize' lists) -> emoji.
FOOD_EMOJIS = {
    # Meat - Specific Animals
    'Beef': '🐄', 'Pork': '🐷', 'Lamb': '🐑', 'Veal': '🐄', 'Ham': '🐷',
    'Poultry': '🐔', 'Chicken': '🐔', 'Duck': '🦆',
    'Game Meat': '🦌', 'Meat': '🥩',
    # Cured/Processed Meats
    'Cured Meat': '🥓', 'Cold Cuts': '🥪', 'Barbecue': '🔥', 'Grilled': '🔥', 'Roast': '🍖',
    # Fish & Seafood - Specific Types
    'Rich Fish': '🐟', 'Lean Fish': '🐟', 'Fish': '🐟', 'Codfish': '🐟',
    'Shellfish': '🦐', 'Seafood': '🦞', 'Sushi': '🍣', 'Sashimi': '🍣',
    # Cheese - Different Types
    'Cheese': '🧀', 'Soft Cheese': '🧀', 'Hard Cheese': '🧀', 'Blue Cheese': '🧀',
    'Goat Cheese': '🐐', 'Maturated Cheese': '🧀', 'Mild Cheese': '🧀', 'Medium-cured Cheese': '🧀',
    # Pasta & Italian
    'Pasta': '🍝', 'Tagliatelle': '🍝', 'Lasagna': '🍝', 'Risotto': '🍚',
    'Pizza': '🍕', 'Eggplant Parmigiana': '🍆',
    # Asian Food
    'Asian Food': '🥢', 'Curry Chicken': '🍛', 'Yakissoba': '🍜', 'Paella': '🥘',
    # Vegetables & Vegetarian
    'Vegetarian': '🥬', 'Salad': '🥗', 'Mushrooms': '🍄', 'Beans': '🫘',
    'Baked Potato': '🥔', 'French Fries': '🍟',
    # Desserts & Sweets
    'Sweet Dessert': '🍰', 'Dessert': '🍰', 'Fruit Dessert': '🍓',
    'Cake': '🍰', 'Cookies': '🍪', 'Chocolate': '🍫', 'Cream': '🍨',
    'Citric Dessert': '🍋', 'Spiced Fruit Cake': '🍰',
    # Fruits & Nuts
    'Fruit': '🍇', 'Dried Fruits': '🍇', 'Chestnut': '🌰',
    # Appetizers & Snacks
    'Appetizer': '🍤', 'Snack': '🥨', 'Aperitif': '🥂',
    # Soups & Stews
    'Light Stews': '🍲', 'Soufflé': '🥄',
    # Dishes & Preparations
    'Spicy Food': '🌶️', 'Tomato Dishes': '🍅'
}
# --- Data Download Function ---
def download_data():
    """Return the dataset CSV filename, downloading it from Google Drive if absent.

    The download is streamed into ``<name>.part`` and only renamed to the final
    name once it completed, so an interrupted transfer never leaves a truncated
    CSV behind that later runs would accept as the finished dataset.

    Returns:
        str: The CSV filename, relative to the current working directory.

    Raises:
        Exception: If the HTTP request, the response check, or writing the file
            fails. The original error is chained as ``__cause__``.
    """
    csv_filename = 'XWines_Full_100K_wines.csv'
    if os.path.exists(csv_filename):
        print(f"Using existing dataset: {csv_filename}")
        return csv_filename
    # Convert Google Drive share link to direct download link
    file_id = '1uEEipmKNxdiKUAhjH-K14JOSQ2BLRFss'
    download_url = f'https://drive.google.com/uc?export=download&id={file_id}'
    partial_filename = csv_filename + '.part'
    print("Downloading dataset from Google Drive...")
    try:
        # (connect, read) timeouts so a stalled connection cannot hang startup indefinitely;
        # the `with` closes the streamed connection even if writing fails.
        with requests.get(download_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            # Drive answers some requests with an HTML interstitial page instead of the file;
            # never persist that as the CSV.
            if 'text/html' in response.headers.get('Content-Type', ''):
                raise ValueError("Google Drive returned an HTML page instead of the CSV file")
            with open(partial_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Rename only after the full body was written (os.replace is atomic on POSIX).
        os.replace(partial_filename, csv_filename)
    except Exception as e:
        try:
            os.remove(partial_filename)
        except OSError:
            pass  # Partial file was never created or is already gone.
        raise Exception(f"Failed to download dataset: {str(e)}") from e
    print(f"Dataset downloaded successfully: {csv_filename}")
    return csv_filename
# --- OPTIMIZATION 1: Data Loading & Pre-processing ---
def _parse_list_string(s):
    """Parse a Python list-literal string into a list; return [] for anything else.

    Non-strings, blank strings, text not wrapped in ``[...]`` and malformed
    literals all yield an empty list instead of raising.
    """
    if not isinstance(s, str):
        return []
    s = s.strip()
    if not (s.startswith('[') and s.endswith(']')):
        return []
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input, not only
        # ValueError/SyntaxError; one bad cell must not abort preprocessing.
        return []


@lru_cache(maxsize=1)
def load_and_preprocess_data():
    """Load the wine CSV once and add the parsed/derived columns the dashboard uses.

    Cached with ``lru_cache(maxsize=1)``: every call after the first returns the
    *same* DataFrame object, so callers must not modify it in place.

    Added columns:
        grapes_list / harmonize_list: ``Grapes`` / ``Harmonize`` parsed from their
            list-literal string form into real lists (``[]`` when empty/malformed).
        main_grape: first entry of ``grapes_list``, or ``'Unknown'`` when empty.
        num_grapes: number of entries in ``grapes_list``.
        body_numeric: ``Body`` mapped through ``BODY_MAPPING`` (NaN when unmapped).

    Returns:
        pandas.DataFrame: all CSV columns plus the columns above.

    Raises:
        FileNotFoundError: If the CSV is missing after the download step.
    """
    start_time = time.time()
    print("[TIMING] Starting data loading and preprocessing...")
    csv_filename = download_data()
    print(f"[TIMING] File check/download completed in {time.time() - start_time:.2f}s")
    try:
        csv_start = time.time()
        print("[TIMING] Loading CSV data...")
        # All columns are loaded; low_memory=False reads the file in one pass so
        # column dtypes are inferred consistently.
        df = pd.read_csv(csv_filename, low_memory=False)
        print(f"[TIMING] CSV loaded in {time.time() - csv_start:.2f}s - {len(df):,} wine records")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"CSV file '{csv_filename}' not found.") from err

    # Row-wise parsing of the list-literal columns (one literal_eval per cell).
    parse_start = time.time()
    print("[TIMING] Starting string parsing...")
    df['grapes_list'] = df['Grapes'].fillna('[]').apply(_parse_list_string)
    print(f"[TIMING] Grapes parsing completed in {time.time() - parse_start:.2f}s")
    harmonize_start = time.time()
    df['harmonize_list'] = df['Harmonize'].fillna('[]').apply(_parse_list_string)
    print(f"[TIMING] Harmonize parsing completed in {time.time() - harmonize_start:.2f}s")

    derived_start = time.time()
    df['main_grape'] = df['grapes_list'].apply(lambda grapes: grapes[0] if grapes else 'Unknown')
    df['num_grapes'] = df['grapes_list'].apply(len)
    df['body_numeric'] = df['Body'].map(BODY_MAPPING)
    print(f"[TIMING] Derived columns completed in {time.time() - derived_start:.2f}s")
    total_time = time.time() - start_time
    print(f"[TIMING] Total preprocessing completed in {total_time:.2f}s")
    return df
# --- OPTIMIZATION 2: Vectorized Data Aggregation ---
def get_top_food_pairings(harmonize_list, top_n=3):
    """Summarise the most frequent food pairings across a group of wines.

    Args:
        harmonize_list: Iterable of per-wine pairing lists. Entries that are not
            lists are skipped.
        top_n: Number of most frequent pairings to report.

    Returns:
        dict: ``{'emojis': str, 'names': str}``. ``emojis`` concatenates one
        emoji per top pairing (🍽️ when the pairing has no entry in
        ``FOOD_EMOJIS``), and ``names`` is a comma-separated list of the pairing
        names. Equal counts keep first-seen order, as in ``Counter.most_common``.
        When there are no pairings at all, returns the 🍽️ / ``'General'`` fallback.
    """
    pairing_counts = Counter(
        pairing
        for pairings in harmonize_list
        if isinstance(pairings, list)
        for pairing in pairings
    )
    if not pairing_counts:
        return {'emojis': '🍽️', 'names': 'General'}
    top_items = [name for name, _ in pairing_counts.most_common(top_n)]
    emojis = ''.join(FOOD_EMOJIS.get(name, '🍽️') for name in top_items)
    names = ', '.join(top_items)
    return {'emojis': emojis, 'names': names}
def aggregate_wine_data(df, wine_types, max_grape_count, min_samples_choice, regional_grouping):
    """Filter the preprocessed wines and summarise them per grape/type (and country).

    Args:
        df: Preprocessed frame from ``load_and_preprocess_data``. Not modified.
        wine_types: ``Type`` values to keep. An empty selection, or one that
            contains ``'All'``, keeps every type.
        max_grape_count: Keep wines with at most this many grapes. A value of 5
            or more disables the filter.
        min_samples_choice: Key into ``SAMPLE_THRESHOLDS``. Groups with fewer
            wines than the mapped threshold are dropped.
        regional_grouping: When true, also group by ``Country``.

    Returns:
        pandas.DataFrame: One row per surviving group, with these columns:
        ``count``, ``avg_fullness``, per-group value lists (``abv_list``,
        ``body_list``, ``acidity_list``, ``harmonize_list``), ``region_count``,
        ``winery_count``, ``body_dist``/``acid_dist`` (category -> percent),
        ``pairing_data``/``pairing_emoji``/``pairing_names`` and
        ``wine_type_order``. Rows are sorted by ``wine_type_order`` descending,
        then ``avg_fullness`` ascending. If no group survives the threshold, the
        empty grouped frame is returned without the derived columns.
    """
    agg_start = time.time()
    print(f"[TIMING] Starting aggregation with {len(df):,} records...")
    # Boolean indexing already produces new frames and nothing below mutates
    # them, so the (cached, shared) full dataset is not copied on every call.
    filtered_df = df
    if wine_types and 'All' not in wine_types:
        filtered_df = filtered_df[filtered_df['Type'].isin(wine_types)]
    if max_grape_count < 5:
        filtered_df = filtered_df[filtered_df['num_grapes'] <= max_grape_count]
    group_by_cols = ['main_grape', 'Type']
    if regional_grouping:
        group_by_cols.append('Country')

    groupby_start = time.time()
    agg_df = filtered_df.groupby(group_by_cols).agg(
        count=('ABV', 'size'),
        avg_fullness=('body_numeric', 'mean'),
        abv_list=('ABV', list),
        body_list=('Body', list),
        acidity_list=('Acidity', list),
        harmonize_list=('harmonize_list', list),
        region_count=('RegionName', 'nunique'),
        winery_count=('WineryName', 'nunique')
    ).reset_index()
    print(f"[TIMING] GroupBy aggregation completed in {time.time() - groupby_start:.2f}s")

    min_samples = SAMPLE_THRESHOLDS[min_samples_choice]
    agg_df = agg_df[agg_df['count'] >= min_samples].copy()
    if agg_df.empty:
        return agg_df

    def calc_distribution(values_list, categories):
        """Percentage share of each category in values_list (missing categories -> 0.0)."""
        if not values_list:
            return {cat: 0.0 for cat in categories}
        counts = pd.Series(values_list).value_counts(normalize=True) * 100
        return {cat: counts.get(cat, 0.0) for cat in categories}

    dist_start = time.time()
    agg_df['body_dist'] = agg_df['body_list'].apply(lambda values: calc_distribution(values, BODY_ORDER))
    agg_df['acid_dist'] = agg_df['acidity_list'].apply(lambda values: calc_distribution(values, ACIDITY_ORDER))
    print(f"[TIMING] Distribution calculations completed in {time.time() - dist_start:.2f}s")

    pairing_start = time.time()
    agg_df['pairing_data'] = [get_top_food_pairings(pairings) for pairings in agg_df['harmonize_list']]
    agg_df['pairing_emoji'] = agg_df['pairing_data'].apply(lambda pairing: pairing['emojis'])
    agg_df['pairing_names'] = agg_df['pairing_data'].apply(lambda pairing: pairing['names'])
    print(f"[TIMING] Food pairing calculations completed in {time.time() - pairing_start:.2f}s")

    final_start = time.time()
    agg_df['wine_type_order'] = agg_df['Type'].map(WINE_TYPE_ORDER)
    agg_df = agg_df.sort_values(by=['wine_type_order', 'avg_fullness'], ascending=[False, True])
    print(f"[TIMING] Final sorting completed in {time.time() - final_start:.2f}s")
    total_agg_time = time.time() - agg_start
    print(f"[TIMING] Total aggregation completed in {total_agg_time:.2f}s - {len(agg_df)} combinations")
    return agg_df
# --- OPTIMIZATION 3: Efficient & Clean Chart Creation ---
def _build_row_labels(chart_data, regional_grouping):
    """Return a Series holding one unique y-axis label per row of ``chart_data``.

    A label is "<type emoji> <grape>", plus " <country flag>" when grouping by
    country. Plotly draws every point that shares a category label on the same
    row, so labels that would collide get the wine type (and country) spelled
    out. Collisions happen when Red and Dessert rows of one grape both fall back
    to 🍷, or when two countries without a flag both get 🌍.
    """
    wine_type_emojis = {'Red': '🍷', 'White': '🥂', 'Rosé': '🌸', 'Sparkling': '🍾'}
    labels = chart_data['Type'].map(wine_type_emojis).fillna('🍷') + ' ' + chart_data['main_grape']
    details = chart_data['Type'].astype(str)
    if regional_grouping:
        labels = labels + ' ' + chart_data['Country'].map(COUNTRY_FLAGS).fillna('🌍')
        details = details + ', ' + chart_data['Country'].astype(str)
    collisions = labels.duplicated(keep=False)
    return labels.where(~collisions, labels + ' (' + details + ')')


def _build_hover_texts(chart_data, regional_grouping):
    """Return per-row hover HTML: grape, country (or 'Global'), winery/region/wine counts."""
    # Without regional grouping there is no Country column; a plain string
    # concatenates with the Series below just as well (it has no .astype).
    country_part = chart_data['Country'].astype(str) if regional_grouping else 'Global'
    return (
        '<b>' + chart_data['main_grape'] + ' (' + country_part + ')</b><br>' +
        'Wineries: ' + chart_data['winery_count'].astype(str) + '<br>' +
        'Regions: ' + chart_data['region_count'].astype(str) + '<br>' +
        'Total Wines: ' + chart_data['count'].apply(lambda x: f"{x:,}")
    ).tolist()


def create_wine_chart(chart_data, regional_grouping):
    """Build the five-column Plotly dashboard figure for the aggregated wine rows.

    Columns, left to right: grape label, stacked body-profile bars, stacked
    acidity-profile bars, ABV box plots, and food-pairing emojis. All columns
    share one categorical y-axis keyed by the row labels.

    Args:
        chart_data: Output of ``aggregate_wine_data``: one row per group, already
            sorted. Not modified.
        regional_grouping: True when rows are also split by ``Country``. This adds
            a flag to each label and the country to the hover text.

    Returns:
        plotly.graph_objects.Figure: The dashboard figure. If ``chart_data`` is
        empty, the figure holds only a "No data" annotation.
    """
    chart_start = time.time()
    print(f"[TIMING] Starting chart creation with {len(chart_data)} rows...")
    if chart_data.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available with current filters.", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        return fig
    num_rows = len(chart_data)

    prep_start = time.time()
    y_labels = _build_row_labels(chart_data, regional_grouping).tolist()
    print(f"[TIMING] Label preparation completed in {time.time() - prep_start:.2f}s")

    fig = make_subplots(
        rows=1, cols=5,
        specs=[[{}, {"type": "bar"}, {"type": "bar"}, {"type": "box"}, {}]],
        column_widths=[0.25, 0.22, 0.22, 0.16, 0.05],
        horizontal_spacing=0.025,
        shared_yaxes=True
    )

    hover_start = time.time()
    hover_texts = _build_hover_texts(chart_data, regional_grouping)
    print(f"[TIMING] Hover text preparation completed in {time.time() - hover_start:.2f}s")

    # Column 1: a transparent bar spanning the cell carries the hover text;
    # the visible label is a separate text-only scatter trace.
    fig.add_trace(go.Bar(
        y=y_labels, x=[1] * num_rows, orientation='h',
        marker_color='rgba(0,0,0,0)', showlegend=False,
        hoverinfo='text', hovertext=hover_texts
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        y=y_labels, x=[0.03] * num_rows, mode='text',
        text=y_labels, textposition='middle right',
        textfont={'size': 22, 'color': '#1A1A1A'},
        hoverinfo='none', showlegend=False
    ), row=1, col=1)

    # Column 2: one horizontal bar trace per body category, stacked (barmode='stack').
    body_start = time.time()
    body_colors = {'Very light-bodied': '#FFB6C1', 'Light-bodied': '#CD5C5C', 'Medium-bodied': '#C13636',
                   'Full-bodied': '#8B0000', 'Very full-bodied': '#4B0000'}
    for body_type in BODY_ORDER:
        fig.add_trace(go.Bar(
            y=y_labels, x=[dist.get(body_type, 0) for dist in chart_data['body_dist']],
            name=body_type, orientation='h',
            marker_color=body_colors.get(body_type), showlegend=False,
            hovertemplate=f"{body_type}: %{{x:.1f}}%<extra></extra>"
        ), row=1, col=2)
    print(f"[TIMING] Body traces completed in {time.time() - body_start:.2f}s")

    # Column 3: same layout for acidity categories.
    acid_start = time.time()
    acid_colors = {'Low': '#F5F5DC', 'Medium': '#DAA520', 'High': '#B8860B'}
    for acid_type in ACIDITY_ORDER:
        fig.add_trace(go.Bar(
            y=y_labels, x=[dist.get(acid_type, 0) for dist in chart_data['acid_dist']],
            name=acid_type, orientation='h',
            marker_color=acid_colors.get(acid_type), showlegend=False,
            hovertemplate=f"{acid_type} acidity: %{{x:.1f}}%<extra></extra>"
        ), row=1, col=3)
    print(f"[TIMING] Acidity traces completed in {time.time() - acid_start:.2f}s")

    # Column 4: one box trace per row built from that group's raw ABV values.
    box_start = time.time()
    box_colors = {'Red': '#8B0000', 'White': '#DAA520', 'Rosé': '#CD5C5C', 'Sparkling': '#9370DB'}
    for label, wine_type, abv_values in zip(y_labels, chart_data['Type'], chart_data['abv_list']):
        color = box_colors.get(wine_type, '#6A5ACD')
        fig.add_trace(go.Box(
            y=[label] * len(abv_values), x=abv_values, name=wine_type, orientation='h',
            showlegend=False, marker_color=color, line_color=color,
            hovertemplate="ABV: %{x:.1f}%<extra></extra>"
        ), row=1, col=4)
    print(f"[TIMING] Box plot traces completed in {time.time() - box_start:.2f}s")

    # Column 5: pairing emojis, with the pairing names on hover.
    fig.add_trace(go.Scatter(
        y=y_labels, x=[0.5] * num_rows, mode='text',
        text=chart_data['pairing_emoji'], textposition='middle center',
        textfont={'size': 32}, showlegend=False,
        hoverinfo='text', hovertext=chart_data['pairing_names']
    ), row=1, col=5)

    fig.update_layout(
        title={
            'text': "Wine Characteristics by Grape Variety",
            'x': 0.5,
            'font': {'size': 26, 'color': '#1A1A1A'}
        },
        height=max(600, num_rows * 55),
        barmode='stack', showlegend=False, plot_bgcolor='#FAFAFA', paper_bgcolor='#E8E9EA',
        margin=dict(l=30, r=30, t=100, b=60), boxgap=0.3, bargap=0.2
    )

    # Column titles are paper-coordinate annotations centred over each subplot's x-domain.
    title_start = time.time()
    column_titles = ["Grape Variety", "Body Profile (%)", "Acidity Profile (%)", "Alcohol (ABV %)", "Food Pairing"]
    for i, title in enumerate(column_titles, 1):
        domain = fig.layout[f'xaxis{i if i > 1 else ""}'].domain
        fig.add_annotation(
            x=(domain[0] + domain[1]) / 2, y=1.02,
            xref="paper", yref="paper", text=f"<b>{title}</b>",
            xanchor='center', showarrow=False, font={'size': 20, 'color': '#1A1A1A'}
        )
    print(f"[TIMING] Column titles completed in {time.time() - title_start:.2f}s")

    # Only the body, acidity and ABV columns (2-4) get x tick labels and grid lines.
    axes_start = time.time()
    for col in range(1, 6):
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=col)
        if col in (2, 3, 4):
            fig.update_xaxes(showticklabels=True, showgrid=True, gridcolor='rgba(0,0,0,0.1)',
                             zeroline=False, title_text="", row=1, col=col)
        else:
            fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, title_text="", row=1, col=col)
    print(f"[TIMING] Axes formatting completed in {time.time() - axes_start:.2f}s")

    # Pin the shared categorical y-axis to the row order of chart_data.
    final_axes_start = time.time()
    fig.update_yaxes(categoryorder="array", categoryarray=y_labels, autorange=False, range=[-0.5, num_rows - 0.5],
                     row=1, col=1)
    print(f"[TIMING] Final axis configuration completed in {time.time() - final_axes_start:.2f}s")

    # Shade every second row across all columns for readability.
    bg_start = time.time()
    for i in range(1, num_rows, 2):
        fig.add_hrect(y0=i - 0.5, y1=i + 0.5, fillcolor="#F0F2F3", layer="below", line_width=0, row=1, col="all")
    print(f"[TIMING] Background rectangles completed in {time.time() - bg_start:.2f}s")

    chart_time = time.time() - chart_start
    print(f"[TIMING] Chart creation completed in {chart_time:.2f}s")
    return fig
# --- Gradio Interface Logic ---
def update_dashboard(wine_types, max_grape_count, min_samples_choice, regional_grouping,
                     progress=gr.Progress(track_tqdm=True)):
    """Gradio callback: rebuild the chart and summary line for the current filters.

    Args:
        wine_types: Selected wine ``Type`` values (see ``aggregate_wine_data``).
        max_grape_count: Maximum number of grapes per wine (slider value).
        min_samples_choice: Key of ``SAMPLE_THRESHOLDS`` chosen in the radio.
        regional_grouping: Whether rows are split by country.
        progress: Gradio progress tracker, updated at each stage.

    Returns:
        tuple: ``(figure, markdown_summary)`` for the plot and summary outputs.
    """
    dashboard_start = time.time()
    print("[TIMING] Starting dashboard update...")
    progress(0, desc="Loading and processing data...")
    df = load_and_preprocess_data()
    progress(0.5, desc="Filtering and aggregating...")
    chart_data = aggregate_wine_data(df, wine_types, max_grape_count, min_samples_choice, regional_grouping)
    progress(0.8, desc="Creating chart...")
    fig = create_wine_chart(chart_data, regional_grouping)

    total_combinations = len(chart_data)
    # Plain int so the thousands-separator formatting below does not depend on numpy scalar formatting.
    total_wines = int(chart_data['count'].sum()) if not chart_data.empty else 0
    min_samples = SAMPLE_THRESHOLDS[min_samples_choice]
    grouping_type = "grape+region" if regional_grouping else "grape+type"
    summary = f"📊 Showing **{total_combinations}** {grouping_type} combinations from **{total_wines:,}** wines (min {min_samples} samples each)"
    total_dashboard_time = time.time() - dashboard_start
    print(f"[TIMING] Total dashboard update completed in {total_dashboard_time:.2f}s")
    return fig, summary
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks app: filter controls, a summary line and the chart.

    Every control re-runs ``update_dashboard`` when it changes, and the
    dashboard is also rendered once when the page loads.

    Returns:
        gr.Blocks: The app, not yet launched.
    """
    with gr.Blocks(title="Wine Analysis Dashboard", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🍷 Wine Characteristics Dashboard")
        with gr.Row():
            wine_type_filter = gr.CheckboxGroup(
                # Matched verbatim against the CSV 'Type' column via Series.isin, so the
                # spelling (including the accented 'é') must be exact.
                choices=['Red', 'White', 'Rosé', 'Sparkling', 'Dessert', 'Dessert/Port'],
                value=['Red'], label="🍷 Wine Types"
            )
            max_grape_slider = gr.Slider(
                minimum=1, maximum=5, step=1, value=1, label="🍇 Max Grapes per Wine",
                info="1: Varietals, 5: All Blends"
            )
            min_samples_choice = gr.Radio(
                choices=list(SAMPLE_THRESHOLDS.keys()), value='Very Common (250+)',
                label="Minimum Sample Size (Wines per Variety)",
            )
            regional_grouping = gr.Checkbox(
                value=True, label="Split By Country"
            )
        summary_text = gr.Markdown()
        wine_plot = gr.Plot()

        inputs = [wine_type_filter, max_grape_slider, min_samples_choice, regional_grouping]
        outputs = [wine_plot, summary_text]
        # Auto-update when any input changes
        for input_component in inputs:
            input_component.change(update_dashboard, inputs=inputs, outputs=outputs)
        demo.load(fn=update_dashboard, inputs=inputs, outputs=outputs)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start serving it.
    app_interface = create_interface()
    app_interface.launch()