"""Dash application visualising fish-species sequences as a Sankey diagram.

For every sample, the species with a positive abundance are ordered by
decreasing abundance and drawn as a path ``Start -> 1_<sp> -> 2_<sp> -> ...``.
The module is set up for deployment on Hugging Face Spaces.
"""
import colorsys
import json
import logging
import os
import random
import sys
from collections import defaultdict

import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, Input, Output, State, callback_context, dcc, html

# ==================== APP CONFIGURATION ====================
# Hugging Face Spaces configuration
external_stylesheets = [dbc.themes.BOOTSTRAP]
app = Dash(
    __name__,
    external_stylesheets=external_stylesheets,
    suppress_callback_exceptions=True,
)
server = app.server  # Required by Hugging Face Spaces

# Input file paths
data_path = "./gpt2_mdm_median_90ep_last_trained_30inf_batch8_expanded.csv"
species_dict_path = "./dict_id2species.txt"

# Grey used for the "Start" node and whenever a species has no colour.
_GREY_NODE = "rgba(100,100,100,0.8)"
_GREY_LINK = "rgba(100,100,100,0.6)"


# ==================== DATA PROCESSING FUNCTIONS ====================

def load_and_preprocess_data(file_path, dict_path=None):
    """Load the abundance CSV and derive per-sample species sequences.

    Args:
        file_path: path of the CSV. It must contain ``sample_index`` and
            ``sample_num`` columns. Every column whose name starts with
            ``'SP'`` is treated as a species abundance column.
        dict_path: JSON file parsed into ``species_dict``. Defaults to the
            module-level ``species_dict_path``. If the file cannot be read or
            parsed, an error is printed and an empty dict is used.

    Returns:
        dict with keys:
            'df': the raw DataFrame.
            'species_cols': list of species column names.
            'sample_species': {(sample_index, sample_num): [(species, abundance), ...]},
                species with abundance > 0, most abundant first. Only samples
                with at least one such species are included.
            'species_freq': {species: number of samples in which it is present}.
            'sites': sorted unique ``sample_index`` values.
            'top_species': ``species_freq`` items sorted by count, descending.
            'species_dict': the parsed JSON dictionary (or ``{}``).
            'species_color_map': {species column: "rgba(...)" colour string}.

    Raises:
        Whatever ``pd.read_csv`` raises for an unreadable CSV, and ``KeyError``
        if the ``sample_index`` / ``sample_num`` columns are missing.
    """
    if dict_path is None:
        dict_path = species_dict_path

    # The species dictionary is optional: degrade to an empty mapping
    # instead of aborting. Only I/O and parse errors are caught.
    try:
        with open(dict_path, 'r', encoding='utf-8') as f:
            species_dict = json.load(f)
    except (OSError, ValueError):  # json.JSONDecodeError subclasses ValueError
        print(f"Error: Unable to load species dictionary at {dict_path}")
        species_dict = {}

    df = pd.read_csv(file_path)
    species_cols = [col for col in df.columns if col.startswith('SP')]

    sample_species = {}
    for _, row in df.iterrows():
        # Sorting is stable, so species with equal abundance keep column order.
        species_present = sorted(
            ((col, row[col]) for col in species_cols if row[col] > 0),
            key=lambda item: item[1],
            reverse=True,
        )
        if species_present:
            sample_species[(row['sample_index'], row['sample_num'])] = species_present

    # Number of samples in which each species appears.
    species_freq = defaultdict(int)
    for species_list in sample_species.values():
        for species, _ in species_list:
            species_freq[species] += 1

    sites = sorted(df['sample_index'].unique())

    # One colour per species column, so colours are stable across filters.
    species_colors = generate_color_palette(len(species_cols))
    species_color_map = dict(zip(species_cols, species_colors))

    return {
        'df': df,
        'species_cols': species_cols,
        'sample_species': sample_species,
        'species_freq': dict(species_freq),
        'sites': sites,
        'top_species': sorted(species_freq.items(), key=lambda x: x[1], reverse=True),
        'species_dict': species_dict,
        'species_color_map': species_color_map,
    }


# Hand-picked base palette (RGB), used before any generated colours.
_PREDEFINED_RGB = [
    (204, 121, 167),  # Rose
    (86, 180, 86),    # Medium Green
    (213, 94, 94),    # Medium Red
    (86, 180, 180),   # Medium Teal
    (215, 180, 76),   # Medium Gold
    (120, 120, 204),  # Medium Blue
    (225, 153, 76),   # Medium Orange
    (153, 84, 204),   # Medium Purple
    (86, 153, 204),   # Medium Sky Blue
    (204, 76, 153),   # Medium Magenta
    (153, 204, 76),   # Medium Lime
    (229, 153, 153),  # Medium Coral
    (76, 153, 76),    # Forest Green
    (153, 76, 76),    # Brick Red
    (76, 76, 153),    # Navy Blue
    (172, 115, 57),   # Brown
    (204, 204, 57),   # Enhanced Yellow
    (204, 57, 204),   # Enhanced Purple
    (57, 204, 204),   # Enhanced Cyan
    (229, 115, 57),   # Enhanced Orange
    (120, 57, 204),   # Enhanced Indigo
    (57, 204, 120),   # Enhanced Jade
    (204, 57, 115),   # Enhanced Pink
    (153, 172, 230),  # Soft Periwinkle
    (230, 153, 172),  # Soft Pink
    (172, 230, 153),  # Soft Mint
    (230, 230, 153),  # Soft Yellow
    (153, 230, 230),  # Soft Cyan
    (230, 153, 230),  # Soft Lavender
    (132, 94, 57),    # Sienna
    (76, 128, 57),    # Olive Green
    (57, 76, 128),    # Slate Blue
    (128, 57, 76),    # Burgundy
    (57, 128, 94),    # Deep Teal
    (94, 57, 128),    # Amethyst
    (235, 194, 57),   # Amber Gold
    (57, 172, 172),   # Deep Aqua
    (172, 57, 102),   # Raspberry
    (102, 172, 57),   # Apple Green
    (235, 91, 172),   # Hot Pink
]

_GOLDEN_RATIO_CONJUGATE = 0.618033988749895


def generate_color_palette(n_colors, saturation=0.75, value=0.95, alpha=0.8):
    """Return ``n_colors`` distinct ``"rgba(r, g, b, alpha)"`` colour strings.

    The first colours come from the hand-picked ``_PREDEFINED_RGB`` palette.
    Further colours are generated by stepping the hue by the golden-ratio
    conjugate, which spreads successive hues far apart.

    Args:
        n_colors: number of colours wanted; values <= 0 yield an empty list.
        saturation: HSV saturation for generated colours (0..1).
        value: HSV value/brightness for generated colours (0..1).
        alpha: opacity written into every colour string.
    """
    n_colors = max(n_colors, 0)  # a negative slice bound would return colours
    colors = [f"rgba({r}, {g}, {b}, {alpha})" for r, g, b in _PREDEFINED_RGB[:n_colors]]

    hue = 0.5  # arbitrary starting hue
    for _ in range(n_colors - len(colors)):
        hue = (hue + _GOLDEN_RATIO_CONJUGATE) % 1.0
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        # Map 0..1 onto 40..255 so no channel is fully dark, then clamp.
        r, g, b = (max(0, min(int(c * 215 + 40), 255)) for c in rgb)
        colors.append(f"rgba({r}, {g}, {b}, {alpha})")
    return colors


# ==================== SANKEY DIAGRAM GENERATION FUNCTIONS ====================

def create_all_sequences_sankey(data, filtered_sites=None, filtered_species=None,
                                max_species=5, first_species_colors=True):
    """Build a Sankey figure from every (optionally filtered) sample sequence.

    Args:
        data: dict returned by ``load_and_preprocess_data``.
        filtered_sites: ``sample_index`` values to keep; falsy keeps all sites.
        filtered_species: species columns to keep; falsy keeps all. This
            filter is applied before truncating to ``max_species``.
        max_species: maximum number of species (steps) drawn per sample.
        first_species_colors: if True, every link of a sample takes the colour
            of that sample's first (most abundant) species. Otherwise each
            link takes its target species' colour.

    Returns:
        The ``go.Figure`` produced by ``build_sankey_diagram``.
    """
    df = data['df']
    sample_species = data['sample_species']
    species_color_map = data['species_color_map']

    if filtered_sites:
        site_rows = df[df['sample_index'].isin(filtered_sites)]
        # Build a set so each membership test below is O(1).
        wanted_keys = set(zip(site_rows['sample_index'], site_rows['sample_num']))
        samples = {key: seq for key, seq in sample_species.items() if key in wanted_keys}
    else:
        samples = sample_species

    species_filter = set(filtered_species) if filtered_species else None

    links = []
    node_colors = {"Start": _GREY_NODE}

    for species_list in samples.values():
        if species_filter is not None:
            species_list = [(sp, val) for sp, val in species_list if sp in species_filter]
        species_list = species_list[:max_species]
        if not species_list:
            continue

        first_color = species_color_map.get(species_list[0][0], _GREY_LINK)
        # Node ids are "<position>_<species>", so the same species at
        # different positions gives different nodes and the graph stays acyclic.
        source = "Start"
        for step, (species, abundance) in enumerate(species_list, start=1):
            target = f"{step}_{species}"
            if first_species_colors:
                link_color = first_color
            else:
                link_color = species_color_map.get(species, _GREY_LINK)
            links.append({
                'source': source,
                'target': target,
                'value': abundance,
                'color': link_color,
            })
            node_colors[target] = species_color_map.get(species, _GREY_NODE)
            source = target  # the next link starts where this one ended

    return build_sankey_diagram(links, node_colors, title="Fish Species Sequences")


def _message_figure(text, size, color):
    """Return an empty figure showing ``text`` centred in the plot area."""
    fig = go.Figure()
    fig.add_annotation(
        x=0.5, y=0.5, xref="paper", yref="paper",
        text=text,
        showarrow=False,
        font=dict(size=size, color=color),
    )
    return fig


def build_sankey_diagram(links, node_colors=None, title=""):
    """Build a Plotly Sankey figure from a list of link dicts.

    Args:
        links: dicts with 'source', 'target' (node ids), 'value' and an
            optional 'color'. Links sharing a source/target pair are passed
            to Plotly as-is.
        node_colors: optional {node id: colour}. Unlisted nodes are grey.
        title: figure title.

    Returns:
        ``go.Figure``. If ``links`` is empty, it shows a "No data to display"
        message instead.
    """
    if not links:
        return _message_figure("No data to display", 16, "gray")

    node_list = sorted({end for link in links for end in (link['source'], link['target'])})
    node_indices = {node: i for i, node in enumerate(node_list)}

    def format_node_label(node):
        # Ids are "Start" or "<position>_<species code>"; display the latter
        # as "<position>: <species code>". No name lookup is done here.
        if node == "Start":
            return "Start"
        position, sep, species_code = node.partition('_')
        return f"{position}: {species_code}" if sep else node

    node_labels = [format_node_label(node) for node in node_list]
    node_colors = node_colors or {}
    node_color_list = [node_colors.get(node, _GREY_NODE) for node in node_list]

    source_indices = [node_indices[link['source']] for link in links]
    target_indices = [node_indices[link['target']] for link in links]
    values = [link['value'] for link in links]
    link_colors = [link.get('color', _GREY_LINK) for link in links]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=node_labels,
            color=node_color_list,
        ),
        link=dict(
            source=source_indices,
            target=target_indices,
            value=values,
            color=link_colors,
        ),
    )])

    fig.update_layout(
        title_text=title,
        font_size=12,
        height=700,
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig


# ==================== DATA LOADING (runs at import time) ====================
try:
    data = load_and_preprocess_data(data_path)
    print(f"Data loaded successfully: {len(data['sites'])} sites, {len(data['species_cols'])} species")
except Exception as e:  # top-level boundary: keep the app up, callbacks show an error
    print(f"Error loading data: {e}")
    data = None


# ==================== DASH APP LAYOUT ====================
if data:
    site_options = [{'label': f'Site {site}', 'value': site} for site in data['sites']]
else:
    site_options = [{'label': 'No data available', 'value': None}]

app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Visualisation Sankey - Séquences d'Espèces",
                    className="text-center mb-4",
                    style={"color": "#2c3e50", "fontWeight": "bold"})
        ])
    ]),
    dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H5("Paramètres", className="card-title"),

                    # Site selector
                    html.Label("Sélectionner un site:", className="fw-bold"),
                    dcc.Dropdown(
                        id='site-selector',
                        options=site_options,
                        value=None,
                        placeholder="Choisir un site...",
                        style={"marginBottom": "15px"}
                    ),

                    # Max species slider
                    html.Label("Nombre maximum d'espèces:", className="fw-bold"),
                    dcc.Slider(
                        id='max-species-slider',
                        min=1,
                        max=15,
                        step=1,
                        value=5,
                        marks={i: str(i) for i in range(1, 16, 2)},
                        tooltip={"placement": "bottom", "always_visible": True}
                    ),

                    html.Hr(),

                    # Buttons (enabled once a site is selected)
                    dbc.Row([
                        dbc.Col([
                            dbc.Button(
                                "Mettre à jour le diagramme",
                                id="update-button",
                                color="primary",
                                size="sm",
                                className="w-100 mb-2",
                                disabled=True
                            )
                        ])
                    ]),
                    dbc.Row([
                        dbc.Col([
                            dbc.Button(
                                "Exporter en HTML",
                                id="export-button",
                                color="success",
                                size="sm",
                                className="w-100",
                                disabled=True
                            ),
                            dcc.Download(id="download-html")
                        ])
                    ])
                ])
            ])
        ], width=3),

        dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    dcc.Graph(
                        id='sankey-graph',
                        style={'height': '700px'},
                        config={
                            'scrollZoom': True,
                            'displayModeBar': True,
                            'toImageButtonOptions': {
                                'format': 'svg',
                                'filename': 'sankey_diagram',
                                'height': 700,
                                'width': 1000,
                                'scale': 2
                            }
                        }
                    )
                ])
            ])
        ], width=9)
    ]),

    dbc.Row([
        dbc.Col([
            html.Footer(
                "Application de visualisation Sankey pour les séquences d'espèces",
                className="text-center mt-4 mb-2 text-muted"
            )
        ])
    ])
], fluid=True)


# ==================== DASH CALLBACKS ====================

def _selection_figure(selected_site, max_species):
    """Build the Sankey figure for the UI selection (``None`` site = all sites)."""
    filtered_sites = [selected_site] if selected_site is not None else None
    return create_all_sequences_sankey(
        data,
        filtered_sites=filtered_sites,
        filtered_species=None,
        max_species=max_species,
        first_species_colors=True,
    )


@app.callback(
    Output('sankey-graph', 'figure'),
    Input('update-button', 'n_clicks'),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=False
)
def update_sankey(n_clicks, selected_site, max_species):
    """Redraw the Sankey diagram when the update button is clicked.

    On the initial page-load call (no triggering input), an instruction
    message is shown instead of a diagram.
    """
    if not data:
        return _message_figure("Erreur: Impossible de charger les données", 16, "red")

    if callback_context.triggered_id is None:
        fig = _message_figure("← Sélectionnez un site pour afficher le diagramme", 20, "#2c3e50")
        fig.update_layout(height=700)
        return fig

    return _selection_figure(selected_site, max_species)


@app.callback(
    Output("download-html", "data"),
    Input("export-button", "n_clicks"),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=True
)
def export_html(n_clicks, selected_site, max_species):
    """Export the diagram for the current selection as a standalone HTML file."""
    if n_clicks is None or not data:
        return None

    fig = _selection_figure(selected_site, max_species)

    # Plotly config embedded in the exported page.
    config = {
        'scrollZoom': True,
        'displayModeBar': True,
        'editable': True,
        'toImageButtonOptions': {
            'format': 'svg',
            'filename': 'sankey_diagram',
            'height': 800,
            'width': 1100,
            'scale': 2
        }
    }

    # include_plotlyjs=True inlines plotly.js so the file works offline.
    html_str = fig.to_html(include_plotlyjs=True, full_html=True, config=config)

    return dict(
        content=html_str,
        filename=f"sankey_site_{selected_site}_{max_species}species.html"
    )


@app.callback(
    [Output('update-button', 'disabled'),
     Output('export-button', 'disabled')],
    Input('site-selector', 'value')
)
def toggle_button_state(selected_site):
    """Disable both buttons until a site has been selected."""
    disabled = selected_site is None
    return disabled, disabled


# Entry point for Hugging Face Spaces
if __name__ == '__main__':
    # Deployment settings: 7860 is the default Hugging Face Spaces port.
    port = int(os.environ.get('PORT', 7860))
    # Dash 3 removed ``run_server``; use ``run`` and fall back only on old Dash.
    run = getattr(app, 'run', None) or app.run_server
    run(debug=False, host='0.0.0.0', port=port)