# sankey / app.py
# Author: Jul34
# Commit 996635f: "🚀 Deploy Sankey Species Visualization App"
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import colorsys
import random
from dash import Dash, html, dcc, callback_context, Output, Input, State
import dash_bootstrap_components as dbc
from collections import defaultdict
import os
import logging
import sys
import json
# Hugging Face Spaces configuration: Dash app styled with the Bootstrap theme.
external_stylesheets = [dbc.themes.BOOTSTRAP]
app = Dash(
    __name__,
    external_stylesheets=external_stylesheets,
    # Disables Dash's check that every callback ID exists in the initial layout.
    suppress_callback_exceptions=True,
)
# Underlying Flask server, exposed as the WSGI entry point (required for Hugging Face Spaces).
server = app.server
# Data file locations. They are resolved against this script's directory
# (not the process's working directory) so the app still finds its data when
# it is launched from elsewhere, e.g. by a WSGI server.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(_APP_DIR, "gpt2_mdm_median_90ep_last_trained_30inf_batch8_expanded.csv")
species_dict_path = os.path.join(_APP_DIR, "dict_id2species.txt")
# ==================== DATA PROCESSING FUNCTIONS ====================
def load_and_preprocess_data(file_path):
    """Load the abundance CSV and derive per-sample species sequences.

    The species dictionary at the module-level ``species_dict_path`` is read
    as JSON. If the file cannot be opened or decoded, an error message is
    printed and an empty dict is used instead.

    Args:
        file_path: CSV path passed to ``pandas.read_csv``. Every column whose
            name starts with ``'SP'`` is treated as a species abundance column;
            ``sample_index`` and ``sample_num`` identify a sample.

    Returns:
        A dict with:
            'df': the DataFrame as read from ``file_path``;
            'species_cols': species column names, in file order;
            'sample_species': ``{(sample_index, sample_num): [(species, abundance), ...]}``
                holding only strictly positive abundances, sorted by
                decreasing abundance (samples with none are omitted);
            'species_freq': ``{species: number of samples where it is present}``;
            'sites': sorted unique ``sample_index`` values;
            'top_species': ``species_freq`` items sorted by decreasing count;
            'species_dict': the loaded species dictionary (``{}`` on failure);
            'species_color_map': ``{species column: 'rgba(...)' colour}``.

    Raises:
        Any error raised by ``pandas.read_csv``; ``KeyError`` if the
        ``sample_index`` column (or, when there are rows, ``sample_num``)
        is missing.
    """
    # The species dictionary is optional. Fall back to an empty mapping only
    # on I/O or decoding errors (ValueError covers JSONDecodeError and
    # UnicodeDecodeError), so unrelated bugs are no longer silently hidden.
    try:
        with open(species_dict_path, 'r', encoding='utf-8') as f:
            species_dict = json.load(f)
    except (OSError, ValueError):
        print(f"Error: Unable to load species dictionary at {species_dict_path}")
        species_dict = {}

    df = pd.read_csv(file_path)

    # Species abundance columns are identified by their 'SP' prefix.
    species_cols = [col for col in df.columns if col.startswith('SP')]

    # For each sample, keep the species with a positive abundance, most
    # abundant first; samples without any such species are skipped.
    sample_species = {}
    for _, row in df.iterrows():
        sample_key = (row['sample_index'], row['sample_num'])
        species_present = [(col, row[col]) for col in species_cols if row[col] > 0]
        if species_present:
            # list.sort is stable, so equal abundances keep column order.
            species_present.sort(key=lambda item: item[1], reverse=True)
            sample_species[sample_key] = species_present

    # Number of samples in which each species appears.
    species_freq = defaultdict(int)
    for species_list in sample_species.values():
        for species, _ in species_list:
            species_freq[species] += 1

    sites = sorted(df['sample_index'].unique())

    # One colour per species column (not per present species), so a species
    # keeps the same colour whatever filters are applied later.
    species_color_map = dict(zip(species_cols, generate_color_palette(len(species_cols))))

    return {
        'df': df,
        'species_cols': species_cols,
        'sample_species': sample_species,
        'species_freq': dict(species_freq),
        'sites': sites,
        'top_species': sorted(species_freq.items(), key=lambda x: x[1], reverse=True),
        'species_dict': species_dict,
        'species_color_map': species_color_map
    }
def generate_color_palette(n_colors, saturation=0.75, value=0.95, alpha=0.8):
    """Return ``n_colors`` distinct ``'rgba(r, g, b, a)'`` colour strings.

    The first colours (up to 40) come from a fixed, hand-picked palette. Any
    further colours are generated by stepping the HSV hue by the golden-ratio
    conjugate, which keeps consecutive hues far apart.

    Args:
        n_colors: Number of colours wanted; values <= 0 yield an empty list.
        saturation: HSV saturation of the generated (non-predefined) colours.
        value: HSV value (brightness) of the generated colours.
        alpha: Opacity written into every colour string.

    Returns:
        A list of exactly ``max(n_colors, 0)`` colour strings.
    """
    # Clamp so a negative count cannot turn into a negative slice bound
    # (``[:-3]`` would silently return 37 colours).
    n_colors = max(n_colors, 0)

    # Pre-defined colors with slightly more saturation but still harmonious.
    enhanced_colors = [
        f"rgba(204, 121, 167, {alpha})",  # Rose
        f"rgba(86, 180, 86, {alpha})",  # Medium Green
        f"rgba(213, 94, 94, {alpha})",  # Medium Red
        f"rgba(86, 180, 180, {alpha})",  # Medium Teal
        f"rgba(215, 180, 76, {alpha})",  # Medium Gold
        f"rgba(120, 120, 204, {alpha})",  # Medium Blue
        f"rgba(225, 153, 76, {alpha})",  # Medium Orange
        f"rgba(153, 84, 204, {alpha})",  # Medium Purple
        f"rgba(86, 153, 204, {alpha})",  # Medium Sky Blue
        f"rgba(204, 76, 153, {alpha})",  # Medium Magenta
        f"rgba(153, 204, 76, {alpha})",  # Medium Lime
        f"rgba(229, 153, 153, {alpha})",  # Medium Coral
        f"rgba(76, 153, 76, {alpha})",  # Forest Green
        f"rgba(153, 76, 76, {alpha})",  # Brick Red
        f"rgba(76, 76, 153, {alpha})",  # Navy Blue
        f"rgba(172, 115, 57, {alpha})",  # Brown
        f"rgba(204, 204, 57, {alpha})",  # Enhanced Yellow
        f"rgba(204, 57, 204, {alpha})",  # Enhanced Purple
        f"rgba(57, 204, 204, {alpha})",  # Enhanced Cyan
        f"rgba(229, 115, 57, {alpha})",  # Enhanced Orange
        f"rgba(120, 57, 204, {alpha})",  # Enhanced Indigo
        f"rgba(57, 204, 120, {alpha})",  # Enhanced Jade
        f"rgba(204, 57, 115, {alpha})",  # Enhanced Pink
        f"rgba(153, 172, 230, {alpha})",  # Soft Periwinkle
        f"rgba(230, 153, 172, {alpha})",  # Soft Pink
        f"rgba(172, 230, 153, {alpha})",  # Soft Mint
        f"rgba(230, 230, 153, {alpha})",  # Soft Yellow
        f"rgba(153, 230, 230, {alpha})",  # Soft Cyan
        f"rgba(230, 153, 230, {alpha})",  # Soft Lavender
        f"rgba(132, 94, 57, {alpha})",  # Sienna
        f"rgba(76, 128, 57, {alpha})",  # Olive Green
        f"rgba(57, 76, 128, {alpha})",  # Slate Blue
        f"rgba(128, 57, 76, {alpha})",  # Burgundy
        f"rgba(57, 128, 94, {alpha})",  # Deep Teal
        f"rgba(94, 57, 128, {alpha})",  # Amethyst
        f"rgba(235, 194, 57, {alpha})",  # Amber Gold
        f"rgba(57, 172, 172, {alpha})",  # Deep Aqua
        f"rgba(172, 57, 102, {alpha})",  # Raspberry
        f"rgba(102, 172, 57, {alpha})",  # Apple Green
        f"rgba(235, 91, 172, {alpha})",  # Hot Pink
    ]

    # Use the pre-defined colours first (slicing copies the list).
    colors = enhanced_colors[:n_colors]

    # Generate any remaining colours; the golden-ratio step maximises the hue
    # difference between consecutive colours.
    golden_ratio_conjugate = 0.618033988749895
    h = 0.5  # starting hue
    for _ in range(n_colors - len(colors)):
        h = (h + golden_ratio_conjugate) % 1.0
        rgb = colorsys.hsv_to_rgb(h, saturation, value)
        # Map each channel from [0, 1] into [40, 255] (base of 40 avoids very
        # dark channels), clamping the top in case value > 1.
        r, g, b = (min(int(channel * 215 + 40), 255) for channel in rgb)
        colors.append(f"rgba({r}, {g}, {b}, {alpha})")
    return colors
# ==================== SANKEY DIAGRAM GENERATION FUNCTIONS ====================
def create_all_sequences_sankey(data, filtered_sites=None, filtered_species=None,
                                max_species=5, first_species_colors=True):
    """Build a Sankey diagram of the per-sample species sequences.

    Each sample contributes a chain of links
    ``Start -> 1_<sp1> -> 2_<sp2> -> ...`` that follows its species in the
    order stored in ``data['sample_species']``. Each link's value is the
    abundance of its target species in that sample.

    Args:
        data: Dict returned by ``load_and_preprocess_data``; its 'df',
            'sample_species' and 'species_color_map' entries are used.
        filtered_sites: Optional list of ``sample_index`` values to keep;
            ``None`` or empty keeps every site.
        filtered_species: Optional collection of species columns to keep.
            Other species are removed from each sequence before truncation.
            ``None`` or empty keeps all species.
        max_species: Maximum number of species per sequence, applied after
            species filtering.
        first_species_colors: If True, every link of a sequence takes the
            colour of that sequence's first species; otherwise each link takes
            the colour of its target species.

    Returns:
        The figure produced by ``build_sankey_diagram``.
    """
    df = data['df']
    sample_species = data['sample_species']
    species_color_map = data['species_color_map']

    # Restrict to the selected sites. The allowed keys are kept in a set so
    # each membership test is O(1); the former list made this O(n^2).
    if filtered_sites:
        site_rows = df[df['sample_index'].isin(filtered_sites)]
        allowed_keys = {(row['sample_index'], row['sample_num']) for _, row in site_rows.iterrows()}
        filtered_samples = {k: v for k, v in sample_species.items() if k in allowed_keys}
    else:
        filtered_samples = sample_species

    links = []
    # The Start node is always gray; other node colours are filled in below.
    node_colors = {"Start": "rgba(100,100,100,0.8)"}

    for species_list in filtered_samples.values():
        if filtered_species:
            species_list = [(sp, val) for sp, val in species_list if sp in filtered_species]
        species_list = species_list[:max_species]
        if not species_list:
            continue

        # Colour shared by the whole chain when first_species_colors is set.
        first_color = species_color_map.get(species_list[0][0], "rgba(100,100,100,0.6)")

        # Node names carry the 1-based position so a species can appear at
        # several depths without creating cycles.
        source = "Start"
        for position, (species, abundance) in enumerate(species_list, start=1):
            target = f"{position}_{species}"
            if first_species_colors:
                link_color = first_color
            else:
                link_color = species_color_map.get(species, "rgba(100,100,100,0.6)")
            links.append({
                'source': source,
                'target': target,
                'value': abundance,
                'color': link_color
            })
            node_colors[target] = species_color_map.get(species, "rgba(100,100,100,0.8)")
            source = target

    return build_sankey_diagram(links, node_colors, title="Fish Species Sequences")
def build_sankey_diagram(links, node_colors=None, title=""):
    """Turn a list of link dicts into a Plotly Sankey figure.

    Args:
        links: List of dicts with 'source', 'target' and 'value' keys and an
            optional 'color' key. Node names are strings such as ``"Start"``
            or ``"<position>_<species>"``.
        node_colors: Optional ``{node name: colour}``. Nodes missing from it
            are drawn gray, as is every node when it is ``None`` or empty.
        title: Figure title.

    Returns:
        A ``go.Figure`` with one Sankey trace. When ``links`` is empty, the
        figure instead contains only a "No data to display" annotation.
    """
    default_node_color = "rgba(100,100,100,0.8)"
    default_link_color = 'rgba(100,100,100,0.6)'

    if not links:
        empty_fig = go.Figure()
        empty_fig.add_annotation(
            text="No data to display",
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(size=16, color="gray"),
        )
        return empty_fig

    # Every endpoint becomes a node; nodes are indexed in sorted-name order.
    names = sorted({endpoint for link in links for endpoint in (link['source'], link['target'])})
    index_of = dict(zip(names, range(len(names))))

    def _label(name):
        # "3_SP12" -> "3: SP12"; names without "_" (such as "Start") are kept.
        position, separator, code = name.partition('_')
        return f"{position}: {code}" if separator else name

    color_lookup = node_colors or {}
    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=[_label(name) for name in names],
            color=[color_lookup.get(name, default_node_color) for name in names],
        ),
        link=dict(
            source=[index_of[link['source']] for link in links],
            target=[index_of[link['target']] for link in links],
            value=[link['value'] for link in links],
            color=[link.get('color', default_link_color) for link in links],
        ),
    )

    figure = go.Figure(data=[sankey])
    figure.update_layout(
        title_text=title,
        font_size=12,
        height=700,
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return figure
# Load the dataset once at import time. On any failure the app still starts,
# with `data` set to None.
try:
    data = load_and_preprocess_data(data_path)
except Exception as exc:
    print(f"Error loading data: {exc}")
    data = None
else:
    print(f"Data loaded successfully: {len(data['sites'])} sites, {len(data['species_cols'])} species")
# ==================== DASH APP LAYOUT ====================
# Site dropdown options: one entry per site, or a single placeholder entry
# (value None) when the dataset could not be loaded.
site_options = (
    [{'label': f'Site {site}', 'value': site} for site in data['sites']]
    if data
    else [{'label': 'No data available', 'value': None}]
)
def _single_cell_row(component):
    """Wrap one component in a Bootstrap row with a single column."""
    return dbc.Row([dbc.Col([component])])


def _build_controls_card():
    """Left-hand card: site dropdown, species-count slider and action buttons."""
    return dbc.Card([
        dbc.CardBody([
            html.H5("Paramètres", className="card-title"),
            # Site selector
            html.Label("Sélectionner un site:", className="fw-bold"),
            dcc.Dropdown(
                id='site-selector',
                options=site_options,
                value=None,
                placeholder="Choisir un site...",
                style={"marginBottom": "15px"},
            ),
            # Maximum number of species per sequence (1-15, default 5)
            html.Label("Nombre maximum d'espèces:", className="fw-bold"),
            dcc.Slider(
                id='max-species-slider',
                min=1,
                max=15,
                step=1,
                value=5,
                marks={tick: str(tick) for tick in range(1, 16, 2)},
                tooltip={"placement": "bottom", "always_visible": True},
            ),
            html.Hr(),
            # Both buttons start disabled; a callback enables them once a site is chosen.
            _single_cell_row(
                dbc.Button(
                    "Mettre à jour le diagramme",
                    id="update-button",
                    color="primary",
                    size="sm",
                    className="w-100 mb-2",
                    disabled=True,
                )
            ),
            dbc.Row([
                dbc.Col([
                    dbc.Button(
                        "Exporter en HTML",
                        id="export-button",
                        color="success",
                        size="sm",
                        className="w-100",
                        disabled=True,
                    ),
                    dcc.Download(id="download-html"),
                ])
            ]),
        ])
    ])


def _build_graph_card():
    """Right-hand card holding the Sankey graph (SVG export at 1000x700, scale 2)."""
    graph_config = {
        'scrollZoom': True,
        'displayModeBar': True,
        'toImageButtonOptions': {
            'format': 'svg',
            'filename': 'sankey_diagram',
            'height': 700,
            'width': 1000,
            'scale': 2,
        },
    }
    return dbc.Card([
        dbc.CardBody([
            dcc.Graph(id='sankey-graph', style={'height': '700px'}, config=graph_config)
        ])
    ])


app.layout = dbc.Container(
    [
        _single_cell_row(
            html.H1(
                "Visualisation Sankey - Séquences d'Espèces",
                className="text-center mb-4",
                style={"color": "#2c3e50", "fontWeight": "bold"},
            )
        ),
        dbc.Row([
            dbc.Col([_build_controls_card()], width=3),
            dbc.Col([_build_graph_card()], width=9),
        ]),
        _single_cell_row(
            html.Footer(
                "Application de visualisation Sankey pour les séquences d'espèces",
                className="text-center mt-4 mb-2 text-muted",
            )
        ),
    ],
    fluid=True,
)
# ==================== DASH CALLBACKS ====================
@app.callback(
    Output('sankey-graph', 'figure'),
    Input('update-button', 'n_clicks'),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=False
)
def update_sankey(n_clicks, selected_site, max_species):
    """Redraw the Sankey diagram for the selected site and species limit.

    The callback also runs on page load (``prevent_initial_call=False``).
    Nothing has triggered it then, so an instruction message is shown instead
    of a diagram. If the dataset failed to load (``data`` is falsy), an error
    message is returned whatever the trigger.
    """
    def _message_figure(text, size, color):
        # Blank figure carrying a single centred annotation.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text=text,
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(size=size, color=color),
        )
        return placeholder

    # triggered_id is None when no input fired (the initial page-load call).
    triggered_by_user = callback_context.triggered_id is not None

    if not data:
        return _message_figure("Erreur: Impossible de charger les données", 16, "red")

    if not triggered_by_user:
        intro = _message_figure("← Sélectionnez un site pour afficher le diagramme", 20, "#2c3e50")
        intro.update_layout(height=700)
        return intro

    return create_all_sequences_sankey(
        data,
        filtered_sites=None if selected_site is None else [selected_site],
        filtered_species=None,
        max_species=max_species,
        first_species_colors=True,
    )
@app.callback(
    Output("download-html", "data"),
    Input("export-button", "n_clicks"),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=True
)
def export_html(n_clicks, selected_site, max_species):
    """Build the diagram for the current settings and offer it as an HTML download.

    Returns ``None`` when the button has not been clicked or the dataset
    failed to load. Otherwise returns a ``dcc.Download`` payload dict holding
    the standalone HTML text and a file name that encodes the site and the
    species limit.
    """
    if n_clicks is None or not data:
        return None

    sites = None if selected_site is None else [selected_site]
    figure = create_all_sequences_sankey(
        data,
        filtered_sites=sites,
        filtered_species=None,
        max_species=max_species,
        first_species_colors=True,
    )

    # Plotly config embedded in the exported page ('editable' is Plotly's
    # in-browser editing flag; SVG snapshots are 1100x800 at scale 2).
    export_config = {
        'scrollZoom': True,
        'displayModeBar': True,
        'editable': True,
        'toImageButtonOptions': {
            'format': 'svg',
            'filename': 'sankey_diagram',
            'height': 800,
            'width': 1100,
            'scale': 2,
        },
    }

    # include_plotlyjs=True inlines the plotly.js bundle, so the file is self-contained.
    page = figure.to_html(include_plotlyjs=True, full_html=True, config=export_config)
    return {
        'content': page,
        'filename': f"sankey_site_{selected_site}_{max_species}species.html",
    }
@app.callback(
    [Output('update-button', 'disabled'),
     Output('export-button', 'disabled')],
    Input('site-selector', 'value')
)
def toggle_button_state(selected_site):
    """Keep both action buttons disabled until a site is selected.

    Returns ``(update_disabled, export_disabled)``; both are True exactly
    when ``selected_site`` is None.
    """
    no_site_selected = selected_site is None
    return no_site_selected, no_site_selected
# Entry point for Hugging Face Spaces (and local runs).
if __name__ == '__main__':
    # Deployment configuration: PORT overrides the Hugging Face Spaces default of 7860.
    port = int(os.environ.get('PORT', 7860))
    # ``run_server`` is deprecated in Dash 2 and removed in Dash 3; prefer
    # ``run`` and fall back to ``run_server`` only on older releases lacking it.
    run_app = getattr(app, 'run', None) or app.run_server
    run_app(debug=False, host='0.0.0.0', port=port)