"""Wine characteristics dashboard built on the X-Wines dataset (Gradio + Plotly)."""
import ast
import os
import time
from collections import Counter
from functools import lru_cache

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import requests
from plotly.subplots import make_subplots

# --- Constants and Mappings ---
BODY_ORDER = ['Very light-bodied', 'Light-bodied', 'Medium-bodied', 'Full-bodied', 'Very full-bodied']
ACIDITY_ORDER = ['Low', 'Medium', 'High']
BODY_MAPPING = {'Very light-bodied': 1, 'Light-bodied': 2, 'Medium-bodied': 3, 'Full-bodied': 4,
                'Very full-bodied': 5}
# Sort priority for wine types (higher first); types not listed map to NaN and sort last.
WINE_TYPE_ORDER = {'Red': 2, 'Rosé': 1, 'White': 0}
SAMPLE_THRESHOLDS = {
    'Very Common (250+)': 250,
    'Common (100+)': 100,
    'Uncommon (50+)': 50,
    'Rare (20+)': 20
}
COUNTRY_FLAGS = {
    'United States': '🇺🇸', 'France': '🇫🇷', 'Italy': '🇮🇹', 'Spain': '🇪🇸', 'Germany': '🇩🇪',
    'Australia': '🇦🇺', 'Chile': '🇨🇱', 'Argentina': '🇦🇷', 'Portugal': '🇵🇹', 'South Africa': '🇿🇦',
    'New Zealand': '🇳🇿', 'Austria': '🇦🇹', 'Greece': '🇬🇷', 'Hungary': '🇭🇺', 'Croatia': '🇭🇷',
    'Slovenia': '🇸🇮', 'Canada': '🇨🇦', 'Brazil': '🇧🇷', 'Uruguay': '🇺🇾', 'Israel': '🇮🇱',
    'Lebanon': '🇱🇧', 'Turkey': '🇹🇷', 'Bulgaria': '🇧🇬', 'Romania': '🇷🇴', 'Georgia': '🇬🇪',
    'Moldova': '🇲🇩', 'Switzerland': '🇨🇭', 'England': '🏴'
}
FOOD_EMOJIS = {
    # Meat - Specific Animals
    'Beef': '🐄', 'Pork': '🐷', 'Lamb': '🐑', 'Veal': '🐄', 'Ham': '🐷',
    'Poultry': '🐔', 'Chicken': '🐔', 'Duck': '🦆', 'Game Meat': '🦌', 'Meat': '🥩',
    # Cured/Processed Meats
    'Cured Meat': '🥓', 'Cold Cuts': '🥪', 'Barbecue': '🔥', 'Grilled': '🔥', 'Roast': '🍖',
    # Fish & Seafood - Specific Types
    'Rich Fish': '🐟', 'Lean Fish': '🐟', 'Fish': '🐟', 'Codfish': '🐟',
    'Shellfish': '🦐', 'Seafood': '🦞', 'Sushi': '🍣', 'Sashimi': '🍣',
    # Cheese - Different Types
    'Cheese': '🧀', 'Soft Cheese': '🧀', 'Hard Cheese': '🧀', 'Blue Cheese': '🧀',
    'Goat Cheese': '🐐', 'Maturated Cheese': '🧀', 'Mild Cheese': '🧀', 'Medium-cured Cheese': '🧀',
    # Pasta & Italian
    'Pasta': '🍝', 'Tagliatelle': '🍝', 'Lasagna': '🍝', 'Risotto': '🍚', 'Pizza': '🍕',
    'Eggplant Parmigiana': '🍆',
    # Asian Food
    'Asian Food': '🥢', 'Curry Chicken': '🍛', 'Yakissoba': '🍜', 'Paella': '🥘',
    # Vegetables & Vegetarian
    'Vegetarian': '🥬', 'Salad': '🥗', 'Mushrooms': '🍄', 'Beans': '🫘', 'Baked Potato': '🥔',
    'French Fries': '🍟',
    # Desserts & Sweets
    'Sweet Dessert': '🍰', 'Dessert': '🍰', 'Fruit Dessert': '🍓', 'Cake': '🍰', 'Cookies': '🍪',
    'Chocolate': '🍫', 'Cream': '🍨', 'Citric Dessert': '🍋', 'Spiced Fruit Cake': '🍰',
    # Fruits & Nuts
    'Fruit': '🍇', 'Dried Fruits': '🍇', 'Chestnut': '🌰',
    # Appetizers & Snacks
    'Appetizer': '🍤', 'Snack': '🥨', 'Aperitif': '🥂',
    # Soups & Stews
    'Light Stews': '🍲', 'Soufflé': '🥄',
    # Dishes & Preparations
    'Spicy Food': '🌶️', 'Tomato Dishes': '🍅'
}


# --- Data Download Function ---
def download_data():
    """Return the dataset filename, downloading it from Google Drive if not already present.

    The download is streamed into a temporary ``.part`` file and renamed into
    place only after it completes. An interrupted download therefore never
    leaves a truncated CSV that a later run would mistake for a valid cache
    (the cache check is a plain ``os.path.exists``).

    Returns:
        str: Filename of the CSV in the current working directory.

    Raises:
        Exception: If the HTTP request or the file write fails (original error chained).
    """
    csv_filename = 'XWines_Full_100K_wines.csv'
    if os.path.exists(csv_filename):
        print(f"Using existing dataset: {csv_filename}")
        return csv_filename

    # Convert Google Drive share link to direct download link
    file_id = '1uEEipmKNxdiKUAhjH-K14JOSQ2BLRFss'
    download_url = f'https://drive.google.com/uc?export=download&id={file_id}'
    partial_filename = csv_filename + '.part'

    print("Downloading dataset from Google Drive...")
    try:
        # (connect, read) timeouts so a stalled connection cannot hang app startup forever.
        with requests.get(download_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: the final name only ever refers to a complete file.
        os.replace(partial_filename, csv_filename)
    except Exception as e:
        # Best-effort removal of the incomplete download before re-raising.
        try:
            os.remove(partial_filename)
        except OSError:
            pass
        raise Exception(f"Failed to download dataset: {str(e)}") from e

    print(f"Dataset downloaded successfully: {csv_filename}")
    return csv_filename


# --- OPTIMIZATION 1: Data Loading & Pre-processing ---
@lru_cache(maxsize=1)
def load_and_preprocess_data():
    """Load the CSV once and add the derived columns the dashboard needs.

    The result is memoized by ``lru_cache``, so every caller receives the SAME
    DataFrame object; callers must treat it as read-only.

    Added columns:
        grapes_list / harmonize_list: parsed from the ``Grapes`` / ``Harmonize``
            list-literal strings ([] when missing or unparseable).
        main_grape: first grape in ``grapes_list`` or ``'Unknown'``.
        num_grapes: length of ``grapes_list``.
        body_numeric: ``Body`` mapped through ``BODY_MAPPING`` (NaN if unmapped).

    Raises:
        FileNotFoundError: If the CSV is missing after the download step.
    """
    start_time = time.time()
    print("[TIMING] Starting data loading and preprocessing...")
    csv_filename = download_data()
    print(f"[TIMING] File check/download completed in {time.time() - start_time:.2f}s")

    try:
        csv_start = time.time()
        print("[TIMING] Loading CSV data...")
        # All columns are loaded; low_memory=False avoids mixed-dtype chunk inference.
        df = pd.read_csv(csv_filename, low_memory=False)
        print(f"[TIMING] CSV loaded in {time.time() - csv_start:.2f}s - {len(df):,} wine records")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"CSV file '{csv_filename}' not found.") from err

    def parse_list_string(s):
        """Parse a list literal such as "['Merlot', 'Syrah']"; return [] for anything else."""
        if not isinstance(s, str):
            return []
        s = s.strip()
        if not (s.startswith('[') and s.endswith(']')):
            return []
        try:
            parsed = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            # TypeError covers literals such as unhashable set/dict members.
            return []
        return parsed if isinstance(parsed, list) else []

    # Row-wise parsing with Series.apply (literal_eval cannot be vectorized).
    parse_start = time.time()
    print("[TIMING] Starting string parsing...")
    df['grapes_list'] = df['Grapes'].fillna('[]').apply(parse_list_string)
    print(f"[TIMING] Grapes parsing completed in {time.time() - parse_start:.2f}s")

    harmonize_start = time.time()
    df['harmonize_list'] = df['Harmonize'].fillna('[]').apply(parse_list_string)
    print(f"[TIMING] Harmonize parsing completed in {time.time() - harmonize_start:.2f}s")

    derived_start = time.time()
    df['main_grape'] = df['grapes_list'].apply(lambda x: x[0] if x else 'Unknown')
    df['num_grapes'] = df['grapes_list'].apply(len)
    df['body_numeric'] = df['Body'].map(BODY_MAPPING)
    print(f"[TIMING] Derived columns completed in {time.time() - derived_start:.2f}s")

    total_time = time.time() - start_time
    print(f"[TIMING] Total preprocessing completed in {total_time:.2f}s")
    return df


# --- OPTIMIZATION 2: Vectorized Data Aggregation ---
def get_top_food_pairings(harmonize_list, top_n=3):
    """Summarize the most frequent food pairings across a group of wines.

    Args:
        harmonize_list: Iterable of per-wine pairing lists; non-list entries are ignored.
        top_n: Number of most common pairings to keep.

    Returns:
        dict: ``{'emojis': str, 'names': str}``. Unknown foods use '🍽️'.
        ``{'emojis': '🍽️', 'names': 'General'}`` is returned when there are no pairings.
    """
    all_pairings = [item for sublist in harmonize_list if isinstance(sublist, list) for item in sublist]
    if not all_pairings:
        return {'emojis': '🍽️', 'names': 'General'}

    top_items = Counter(all_pairings).most_common(top_n)
    emojis = ''.join(FOOD_EMOJIS.get(name, '🍽️') for name, _ in top_items)
    names = ', '.join(name for name, _ in top_items)
    return {'emojis': emojis, 'names': names}


def aggregate_wine_data(df, wine_types, max_grape_count, min_samples_choice, regional_grouping):
    """Filter wines and aggregate them per grape/type (and optionally country).

    Args:
        df: Preprocessed frame from ``load_and_preprocess_data`` (read, never mutated).
        wine_types: Selected ``Type`` values. An empty selection or one containing
            'All' applies no type filter.
        max_grape_count: Keep wines with at most this many grapes; 5 or more disables the filter.
        min_samples_choice: Key into ``SAMPLE_THRESHOLDS`` giving the minimum group size.
        regional_grouping: If True, also group by ``Country``.

    Returns:
        pd.DataFrame: One row per group, sorted by wine-type priority then average body.
        Returns an empty frame (without the derived distribution/pairing columns)
        if no group meets the sample threshold.
    """
    agg_start = time.time()
    print(f"[TIMING] Starting aggregation with {len(df):,} records...")

    # Boolean indexing below always produces new frames and groupby does not
    # mutate its input, so the cached frame does not need a defensive copy.
    filtered_df = df
    if wine_types and 'All' not in wine_types:
        filtered_df = filtered_df[filtered_df['Type'].isin(wine_types)]
    if max_grape_count < 5:
        filtered_df = filtered_df[filtered_df['num_grapes'] <= max_grape_count]

    group_by_cols = ['main_grape', 'Type']
    if regional_grouping:
        group_by_cols.append('Country')

    groupby_start = time.time()
    agg_df = filtered_df.groupby(group_by_cols).agg(
        count=('ABV', 'size'),
        avg_fullness=('body_numeric', 'mean'),
        abv_list=('ABV', list),
        body_list=('Body', list),
        acidity_list=('Acidity', list),
        harmonize_list=('harmonize_list', list),
        region_count=('RegionName', 'nunique'),
        winery_count=('WineryName', 'nunique')
    ).reset_index()
    print(f"[TIMING] GroupBy aggregation completed in {time.time() - groupby_start:.2f}s")

    min_samples = SAMPLE_THRESHOLDS[min_samples_choice]
    agg_df = agg_df[agg_df['count'] >= min_samples].copy()
    if agg_df.empty:
        return agg_df

    def calc_distribution(values_list, categories):
        """Percentage share of each category in *values_list* (NaNs excluded; 0.0 if absent)."""
        if not values_list:
            return {cat: 0.0 for cat in categories}
        counts = pd.Series(values_list).value_counts(normalize=True) * 100
        return {cat: counts.get(cat, 0.0) for cat in categories}

    dist_start = time.time()
    agg_df['body_dist'] = agg_df['body_list'].apply(lambda x: calc_distribution(x, BODY_ORDER))
    agg_df['acid_dist'] = agg_df['acidity_list'].apply(lambda x: calc_distribution(x, ACIDITY_ORDER))
    print(f"[TIMING] Distribution calculations completed in {time.time() - dist_start:.2f}s")

    pairing_start = time.time()
    pairing_data = [get_top_food_pairings(harmonize) for harmonize in agg_df['harmonize_list']]
    agg_df['pairing_data'] = pairing_data
    agg_df['pairing_emoji'] = [p['emojis'] for p in pairing_data]
    agg_df['pairing_names'] = [p['names'] for p in pairing_data]
    print(f"[TIMING] Food pairing calculations completed in {time.time() - pairing_start:.2f}s")

    final_start = time.time()
    agg_df['wine_type_order'] = agg_df['Type'].map(WINE_TYPE_ORDER)
    agg_df = agg_df.sort_values(by=['wine_type_order', 'avg_fullness'], ascending=[False, True])
    print(f"[TIMING] Final sorting completed in {time.time() - final_start:.2f}s")

    total_agg_time = time.time() - agg_start
    print(f"[TIMING] Total aggregation completed in {total_agg_time:.2f}s - {len(agg_df)} combinations")
    return agg_df


# --- OPTIMIZATION 3: Efficient & Clean Chart Creation ---
def _make_unique_labels(labels):
    """Return *labels* with repeated strings made distinct by zero-width-space suffixes.

    Plotly merges identical strings on a categorical axis, which would draw two
    different groups on one row. Examples: Red and Dessert wines both fall back
    to '🍷', and two countries without a flag both get '🌍'. The invisible suffix
    keeps the rows separate without changing what is displayed.
    """
    seen = Counter()
    unique = []
    for label in labels:
        unique.append(label + '\u200b' * seen[label])
        seen[label] += 1
    return unique


def create_wine_chart(chart_data, regional_grouping):
    """Build the five-column dashboard figure.

    The columns are: label, body distribution, acidity distribution, ABV box
    plot, and food-pairing emojis.

    Args:
        chart_data: Output of ``aggregate_wine_data``; one figure row per frame row.
        regional_grouping: Whether ``chart_data`` has a ``Country`` column to show.

    Returns:
        go.Figure: The chart, or a figure with a "no data" annotation when empty.
    """
    chart_start = time.time()
    print(f"[TIMING] Starting chart creation with {len(chart_data)} rows...")

    if chart_data.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available with current filters.", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        return fig

    num_rows = len(chart_data)

    # Row labels: wine-type emoji + grape (+ country flag).
    prep_start = time.time()
    wine_type_emojis = {'Red': '🍷', 'White': '🥂', 'Rosé': '🌸', 'Sparkling': '🍾'}
    chart_data = chart_data.copy()  # Don't mutate the caller's frame / avoid SettingWithCopyWarning
    chart_data['wine_emoji'] = chart_data['Type'].map(wine_type_emojis).fillna('🍷')
    if regional_grouping:
        chart_data['flag'] = chart_data['Country'].map(COUNTRY_FLAGS).fillna('🌍')
        chart_data['grape_label'] = (chart_data['wine_emoji'] + ' ' + chart_data['main_grape']
                                     + ' ' + chart_data['flag'])
    else:
        chart_data['grape_label'] = chart_data['wine_emoji'] + ' ' + chart_data['main_grape']
    print(f"[TIMING] Label preparation completed in {time.time() - prep_start:.2f}s")

    y_labels = _make_unique_labels(chart_data['grape_label'].tolist())

    fig = make_subplots(
        rows=1, cols=5,
        specs=[[{}, {"type": "bar"}, {"type": "bar"}, {"type": "box"}, {}]],
        column_widths=[0.25, 0.22, 0.22, 0.16, 0.05],
        horizontal_spacing=0.025,
        shared_yaxes=True
    )

    # Hover text for the label column. Without regional grouping there is no
    # Country column, so a scalar 'Global' is broadcast. Calling .astype on the
    # scalar string previously raised AttributeError.
    hover_start = time.time()
    country_part = chart_data['Country'].astype(str) if regional_grouping else 'Global'
    hover_texts = (
        '<b>' + chart_data['main_grape'].astype(str) + ' (' + country_part + ')</b><br>'
        + 'Wineries: ' + chart_data['winery_count'].astype(str) + '<br>'
        + 'Regions: ' + chart_data['region_count'].astype(str) + '<br>'
        + 'Total Wines: ' + chart_data['count'].apply(lambda x: f"{x:,}")
    ).tolist()
    print(f"[TIMING] Hover text preparation completed in {time.time() - hover_start:.2f}s")

    # Column 1: an invisible full-width bar (hover target) plus the label text.
    fig.add_trace(go.Bar(
        y=y_labels, x=[1] * num_rows, orientation='h', marker_color='rgba(0,0,0,0)',
        showlegend=False, hoverinfo='text', hovertext=hover_texts
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        y=y_labels, x=[0.03] * num_rows, mode='text', text=y_labels, textposition='middle right',
        textfont={'size': 22, 'color': '#1A1A1A'}, hoverinfo='none', showlegend=False
    ), row=1, col=1)

    # Column 2: stacked body-distribution bars (one trace per body category).
    body_start = time.time()
    body_colors = {'Very light-bodied': '#FFB6C1', 'Light-bodied': '#CD5C5C', 'Medium-bodied': '#C13636',
                   'Full-bodied': '#8B0000', 'Very full-bodied': '#4B0000'}
    for body_type in BODY_ORDER:
        fig.add_trace(go.Bar(
            y=y_labels, x=[d.get(body_type, 0) for d in chart_data['body_dist']], name=body_type,
            orientation='h', marker_color=body_colors.get(body_type), showlegend=False,
            hovertemplate=f"{body_type}: %{{x:.1f}}%"
        ), row=1, col=2)
    print(f"[TIMING] Body traces completed in {time.time() - body_start:.2f}s")

    # Column 3: stacked acidity-distribution bars.
    acid_start = time.time()
    acid_colors = {'Low': '#F5F5DC', 'Medium': '#DAA520', 'High': '#B8860B'}
    for acid_type in ACIDITY_ORDER:
        fig.add_trace(go.Bar(
            y=y_labels, x=[d.get(acid_type, 0) for d in chart_data['acid_dist']], name=acid_type,
            orientation='h', marker_color=acid_colors.get(acid_type), showlegend=False,
            hovertemplate=f"{acid_type} acidity: %{{x:.1f}}%"
        ), row=1, col=3)
    print(f"[TIMING] Acidity traces completed in {time.time() - acid_start:.2f}s")

    # Column 4: one ABV box per row, colored by wine type.
    box_start = time.time()
    box_colors = {'Red': '#8B0000', 'White': '#DAA520', 'Rosé': '#CD5C5C', 'Sparkling': '#9370DB'}
    for label, abv_values, wine_type in zip(y_labels, chart_data['abv_list'], chart_data['Type']):
        color = box_colors.get(wine_type, '#6A5ACD')
        fig.add_trace(go.Box(
            y=[label] * len(abv_values), x=abv_values, name=wine_type, orientation='h',
            showlegend=False, marker_color=color, line_color=color,
            hovertemplate="ABV: %{x:.1f}%"
        ), row=1, col=4)
    print(f"[TIMING] Box plot traces completed in {time.time() - box_start:.2f}s")

    # Column 5: food-pairing emojis with the pairing names on hover.
    fig.add_trace(go.Scatter(
        y=y_labels, x=[0.5] * num_rows, mode='text', text=chart_data['pairing_emoji'].tolist(),
        textposition='middle center', textfont={'size': 32}, showlegend=False,
        hoverinfo='text', hovertext=chart_data['pairing_names'].tolist()
    ), row=1, col=5)

    fig.update_layout(
        title={'text': "Wine Characteristics by Grape Variety", 'x': 0.5,
               'font': {'size': 26, 'color': '#1A1A1A'}},
        height=max(600, num_rows * 55),
        barmode='stack',
        showlegend=False,
        plot_bgcolor='#FAFAFA',
        paper_bgcolor='#E8E9EA',
        margin=dict(l=30, r=30, t=100, b=60),
        boxgap=0.3,
        bargap=0.2
    )

    # Column titles centered over each subplot's x-domain.
    column_titles = ["Grape Variety", "Body Profile (%)", "Acidity Profile (%)", "Alcohol (ABV %)",
                     "Food Pairing"]
    title_start = time.time()
    for i, title in enumerate(column_titles, 1):
        domain = fig.layout[f'xaxis{i if i > 1 else ""}'].domain
        fig.add_annotation(
            x=(domain[0] + domain[1]) / 2, y=1.02, xref="paper", yref="paper",
            text=f"{title}", xanchor='center', showarrow=False,
            font={'size': 20, 'color': '#1A1A1A'}
        )
    print(f"[TIMING] Column titles completed in {time.time() - title_start:.2f}s")

    axes_start = time.time()
    for i in range(1, 6):
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=i)
        # Show x-axis labels only for body, acidity, and ABV columns
        if i in (2, 3, 4):
            fig.update_xaxes(showticklabels=True, showgrid=True, gridcolor='rgba(0,0,0,0.1)',
                             zeroline=False, title_text="", row=1, col=i)
        else:
            fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, title_text="",
                             row=1, col=i)
    print(f"[TIMING] Axes formatting completed in {time.time() - axes_start:.2f}s")

    # Pin the category order so rows follow the aggregation's sort order exactly.
    final_axes_start = time.time()
    fig.update_yaxes(categoryorder="array", categoryarray=y_labels, autorange=False,
                     range=[-0.5, num_rows - 0.5], row=1, col=1)
    print(f"[TIMING] Final axis configuration completed in {time.time() - final_axes_start:.2f}s")

    # Shade every odd row (category positions are integers) for readability.
    bg_start = time.time()
    for i in range(1, num_rows, 2):
        fig.add_hrect(y0=i - 0.5, y1=i + 0.5, fillcolor="#F0F2F3", layer="below", line_width=0,
                      row=1, col="all")
    print(f"[TIMING] Background rectangles completed in {time.time() - bg_start:.2f}s")

    chart_time = time.time() - chart_start
    print(f"[TIMING] Chart creation completed in {chart_time:.2f}s")
    return fig


# --- Gradio Interface Logic ---
def update_dashboard(wine_types, max_grape_count, min_samples_choice, regional_grouping,
                     progress=gr.Progress(track_tqdm=True)):
    """Recompute the chart and summary markdown for the current control values.

    Returns:
        tuple: ``(plotly figure, summary markdown string)``.
    """
    dashboard_start = time.time()
    print("[TIMING] Starting dashboard update...")

    progress(0, desc="Loading and processing data...")
    df = load_and_preprocess_data()

    progress(0.5, desc="Filtering and aggregating...")
    chart_data = aggregate_wine_data(df, wine_types, max_grape_count, min_samples_choice, regional_grouping)

    progress(0.8, desc="Creating chart...")
    fig = create_wine_chart(chart_data, regional_grouping)

    total_combinations = len(chart_data)
    total_wines = chart_data['count'].sum() if not chart_data.empty else 0
    min_samples = SAMPLE_THRESHOLDS[min_samples_choice]
    grouping_type = "grape+region" if regional_grouping else "grape+type"

    summary = (f"📊 Showing **{total_combinations}** {grouping_type} combinations from "
               f"**{total_wines:,}** wines (min {min_samples} samples each)")

    total_dashboard_time = time.time() - dashboard_start
    print(f"[TIMING] Total dashboard update completed in {total_dashboard_time:.2f}s")
    return fig, summary


# Create Gradio interface
def create_interface():
    """Assemble the Gradio Blocks app: filter controls, a summary line, and the chart.

    Every filter control re-runs ``update_dashboard`` when its value changes,
    and the dashboard is also rendered once when the page loads.

    Returns:
        gr.Blocks: The assembled interface; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Wine Analysis Dashboard", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🍷 Wine Characteristics Dashboard")

        with gr.Row():
            type_picker = gr.CheckboxGroup(
                label="🍷 Wine Types",
                choices=['Red', 'White', 'Rosé', 'Sparkling', 'Dessert', 'Dessert/Port'],
                value=['Red'],
            )
            grape_limit = gr.Slider(
                label="🍇 Max Grapes per Wine",
                info="1: Varietals, 5: All Blends",
                minimum=1,
                maximum=5,
                step=1,
                value=1,
            )
            sample_picker = gr.Radio(
                label="Minimum Sample Size (Wines per Variety)",
                choices=list(SAMPLE_THRESHOLDS.keys()),
                value='Very Common (250+)',
            )
            country_split = gr.Checkbox(label="Split By Country", value=True)

        summary_md = gr.Markdown()
        chart = gr.Plot()

        controls = [type_picker, grape_limit, sample_picker, country_split]
        wiring = dict(inputs=controls, outputs=[chart, summary_md])

        # Re-render on any control change, and once on initial page load.
        for control in controls:
            control.change(update_dashboard, **wiring)
        app.load(fn=update_dashboard, **wiring)

    return app


if __name__ == "__main__":
    create_interface().launch()