|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
import colorsys |
|
|
import random |
|
|
from dash import Dash, html, dcc, callback_context, Output, Input, State |
|
|
import dash_bootstrap_components as dbc |
|
|
from collections import defaultdict |
|
|
import os |
|
|
import logging |
|
|
import sys |
|
|
import json |
|
|
|
|
|
|
|
|
# Bootstrap theme for the dash-bootstrap-components layout below.
external_stylesheets = [dbc.themes.BOOTSTRAP]

# suppress_callback_exceptions lets callbacks reference component ids that
# are not present in the initial layout.
app = Dash(__name__, external_stylesheets=external_stylesheets, suppress_callback_exceptions=True)

# Underlying Flask server, exposed for WSGI hosts (e.g. `gunicorn module:server`).
server = app.server

# Input files. The defaults are relative to the current working directory;
# the environment variables allow pointing the app at other files without
# editing the source (e.g. when the server is started from another directory).
data_path = os.environ.get(
    "SANKEY_DATA_PATH",
    "./gpt2_mdm_median_90ep_last_trained_30inf_batch8_expanded.csv",
)

species_dict_path = os.environ.get("SANKEY_SPECIES_DICT_PATH", "./dict_id2species.txt")
|
|
|
|
|
|
|
|
|
|
|
def load_and_preprocess_data(file_path, dict_path=None):
    """Load the abundance CSV and derive the lookup structures used by the app.

    Args:
        file_path: Path to the CSV file. It must contain ``sample_index`` and
            ``sample_num`` columns; every column whose name starts with
            ``'SP'`` is treated as a species abundance column.
        dict_path: Path to a JSON species dictionary. Defaults to the
            module-level ``species_dict_path``. Loading is best effort: a
            missing, unreadable or malformed file is reported and replaced
            by an empty dict.

    Returns:
        A dict with keys:
            ``'df'``: the raw DataFrame;
            ``'species_cols'``: the ``'SP*'`` column names, in file order;
            ``'sample_species'``: maps ``(sample_index, sample_num)`` to a
                list of ``(species_col, abundance)`` pairs with abundance > 0,
                sorted by abundance descending (samples with no positive
                abundance are omitted; duplicate keys keep the last row);
            ``'species_freq'``: number of samples in which each species is
                present;
            ``'sites'``: sorted unique ``sample_index`` values;
            ``'top_species'``: ``species_freq`` items sorted by count,
                descending;
            ``'species_dict'``: the loaded JSON dictionary (or ``{}``);
            ``'species_color_map'``: species column -> rgba colour string.

    Raises:
        Whatever ``pandas.read_csv`` raises for an unreadable CSV, and
        ``KeyError`` if ``sample_index`` / ``sample_num`` are missing.
    """
    if dict_path is None:
        dict_path = species_dict_path

    try:
        # JSON text is UTF-8 by specification; do not rely on the platform default.
        with open(dict_path, 'r', encoding='utf-8') as f:
            species_dict = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError. Anything else
        # (including KeyboardInterrupt) propagates instead of being swallowed.
        print(f"Error: Unable to load species dictionary at {dict_path}")
        species_dict = {}

    df = pd.read_csv(file_path)

    species_cols = [col for col in df.columns if col.startswith('SP')]

    # Per sample: the species with a positive abundance, most abundant first.
    sample_species = {}
    for _, row in df.iterrows():
        present = [(col, row[col]) for col in species_cols if row[col] > 0]
        if present:
            present.sort(key=lambda item: item[1], reverse=True)
            sample_species[(row['sample_index'], row['sample_num'])] = present

    # Number of samples in which each species appears.
    species_freq = defaultdict(int)
    for species_list in sample_species.values():
        for species, _ in species_list:
            species_freq[species] += 1

    sites = sorted(df['sample_index'].unique())

    # One colour per species column, assigned in column order.
    species_colors = generate_color_palette(len(species_cols))
    species_color_map = dict(zip(species_cols, species_colors))

    return {
        'df': df,
        'species_cols': species_cols,
        'sample_species': sample_species,
        'species_freq': dict(species_freq),
        'sites': sites,
        'top_species': sorted(species_freq.items(), key=lambda x: x[1], reverse=True),
        'species_dict': species_dict,
        'species_color_map': species_color_map,
    }
|
|
|
|
|
def generate_color_palette(n_colors, saturation=0.75, value=0.95, alpha=0.8):
    """Return ``n_colors`` CSS colour strings of the form ``"rgba(r, g, b, alpha)"``.

    The first colours (up to 40) come from a fixed, hand-picked list. Any
    further colours are generated as follows:

    * the hue is stepped by the golden-ratio conjugate, starting from 0.5;
    * it is converted with ``colorsys.hsv_to_rgb(hue, saturation, value)``;
    * each channel is rescaled from [0, 1] to [40, 255].

    Args:
        n_colors: Number of colours wanted. Zero or a negative count yields ``[]``.
        saturation: HSV saturation for the generated colours beyond the fixed list.
        value: HSV value (brightness) for the generated colours beyond the fixed list.
        alpha: Opacity written into every colour string.

    Returns:
        A list of exactly ``max(n_colors, 0)`` strings.
    """
    if n_colors <= 0:
        # Guard: a negative count would otherwise slice the base list from the end.
        return []

    base_rgb = [
        (204, 121, 167), (86, 180, 86), (213, 94, 94), (86, 180, 180),
        (215, 180, 76), (120, 120, 204), (225, 153, 76), (153, 84, 204),
        (86, 153, 204), (204, 76, 153), (153, 204, 76), (229, 153, 153),
        (76, 153, 76), (153, 76, 76), (76, 76, 153), (172, 115, 57),
        (204, 204, 57), (204, 57, 204), (57, 204, 204), (229, 115, 57),
        (120, 57, 204), (57, 204, 120), (204, 57, 115), (153, 172, 230),
        (230, 153, 172), (172, 230, 153), (230, 230, 153), (153, 230, 230),
        (230, 153, 230), (132, 94, 57), (76, 128, 57), (57, 76, 128),
        (128, 57, 76), (57, 128, 94), (94, 57, 128), (235, 194, 57),
        (57, 172, 172), (172, 57, 102), (102, 172, 57), (235, 91, 172),
    ]
    colors = [f"rgba({r}, {g}, {b}, {alpha})" for r, g, b in base_rgb[:n_colors]]

    golden_ratio_conjugate = 0.618033988749895
    hue = 0.5
    for _ in range(n_colors - len(colors)):
        hue = (hue + golden_ratio_conjugate) % 1.0
        channels = colorsys.hsv_to_rgb(hue, saturation, value)
        # Map [0, 1] onto [40, 255] so no channel is fully dark; clamp in case
        # saturation/value fall outside [0, 1].
        r, g, b = (max(0, min(int(c * 215 + 40), 255)) for c in channels)
        colors.append(f"rgba({r}, {g}, {b}, {alpha})")

    return colors
|
|
|
|
|
|
|
|
|
|
|
def create_all_sequences_sankey(data, filtered_sites=None, filtered_species=None,
                                max_species=5, first_species_colors=True):
    """Create a Sankey figure chaining each sample's species by abundance rank.

    For every sample, its (optionally filtered) species list is truncated to
    ``max_species`` entries. Links are then added as follows:

    * ``"Start" -> "1_<sp1>"``, then ``"1_<sp1>" -> "2_<sp2>"``, and so on;
    * each link is weighted by the target species' abundance.

    Args:
        data: The dict returned by ``load_and_preprocess_data``. The keys used
            here are ``'df'``, ``'sample_species'`` and ``'species_color_map'``.
        filtered_sites: Iterable of ``sample_index`` values to keep. A falsy
            value keeps every sample.
        filtered_species: Collection of species columns to keep. A falsy
            value keeps every species.
        max_species: Maximum chain length per sample (``None`` means no limit).
        first_species_colors: When True, every link of a sample is coloured by
            that sample's first species. Otherwise each link takes its target
            species' colour.

    Returns:
        The figure produced by ``build_sankey_diagram``.
    """
    sample_species = data['sample_species']
    species_color_map = data['species_color_map']

    default_link_color = "rgba(100,100,100,0.6)"
    default_node_color = "rgba(100,100,100,0.8)"

    if filtered_sites:
        df = data['df']
        site_rows = df[df['sample_index'].isin(filtered_sites)]
        # Build the wanted keys as a set: O(1) membership instead of scanning
        # a list once per sample.
        wanted_keys = set(zip(site_rows['sample_index'], site_rows['sample_num']))
        filtered_samples = {k: v for k, v in sample_species.items() if k in wanted_keys}
    else:
        filtered_samples = sample_species

    links = []
    node_colors = {"Start": default_node_color}

    for species_list in filtered_samples.values():
        if filtered_species:
            species_list = [(sp, val) for sp, val in species_list if sp in filtered_species]

        species_list = species_list[:max_species]
        if not species_list:
            continue

        first_sp = species_list[0][0]
        source = "Start"
        # Node ids encode the 1-based rank so the same species at different
        # ranks becomes distinct nodes; each target is the next link's source.
        for position, (species, abundance) in enumerate(species_list, start=1):
            target = f"{position}_{species}"
            color_key = first_sp if first_species_colors else species
            links.append({
                'source': source,
                'target': target,
                'value': abundance,
                'color': species_color_map.get(color_key, default_link_color),
            })
            node_colors[target] = species_color_map.get(species, default_node_color)
            source = target

    return build_sankey_diagram(links, node_colors, title="Fish Species Sequences")
|
|
|
|
|
def build_sankey_diagram(links, node_colors=None, title=""):
    """Build a Plotly Sankey figure from a list of link dicts.

    Args:
        links: List of dicts with keys ``'source'`` and ``'target'`` (node
            ids), ``'value'`` (flow weight) and optionally ``'color'``. Node
            ids are ``"Start"`` or ``"<position>_<species>"``. The latter are
            displayed as ``"<position>: <species>"``.
        node_colors: Optional mapping of node id to colour. Missing nodes get
            a neutral grey.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``. When ``links`` is empty, the figure
        is a placeholder carrying a "No data to display" annotation.
    """
    if not links:
        fig = go.Figure()
        fig.add_annotation(
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            text="No data to display",
            showarrow=False,
            font=dict(size=16, color="gray")
        )
        # Keep the same title/height as a populated diagram and hide the
        # empty cartesian axes so the placeholder does not look like a chart.
        fig.update_layout(
            title_text=title,
            height=700,
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
        )
        return fig

    # Sorted for deterministic node indices across runs.
    node_list = sorted({node for link in links for node in (link['source'], link['target'])})
    node_indices = {node: i for i, node in enumerate(node_list)}

    def format_node_label(node):
        """Turn ``"<position>_<species>"`` into ``"<position>: <species>"``."""
        if node == "Start":
            return "Start"
        position, sep, species_code = node.partition('_')
        if sep:
            return f"{position}: {species_code}"
        return node

    node_labels = [format_node_label(node) for node in node_list]

    color_lookup = node_colors or {}
    node_color_list = [color_lookup.get(node, "rgba(100,100,100,0.8)") for node in node_list]

    source_indices = [node_indices[link['source']] for link in links]
    target_indices = [node_indices[link['target']] for link in links]
    values = [link['value'] for link in links]
    link_colors = [link.get('color', 'rgba(100,100,100,0.6)') for link in links]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=node_labels,
            color=node_color_list
        ),
        link=dict(
            source=source_indices,
            target=target_indices,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        title_text=title,
        font_size=12,
        height=700,
        margin=dict(l=0, r=0, t=50, b=0)
    )

    return fig
|
|
|
|
|
|
|
|
# Load the data once at import time. This is a top-level boundary: on any
# failure the app still starts and the callbacks render an error message.
try:
    data = load_and_preprocess_data(data_path)
    print(f"Data loaded successfully: {len(data['sites'])} sites, {len(data['species_cols'])} species")
except Exception as e:
    print(f"Error loading data: {e}")
    data = None

if data:
    site_options = [{'label': f'Site {site}', 'value': site} for site in data['sites']]
else:
    # dcc.Dropdown option values must be str/number/bool, so `None` is not a
    # valid value. Show the message as a disabled, unselectable entry instead.
    site_options = [{'label': 'No data available', 'value': '', 'disabled': True}]
|
|
|
|
|
# --- Layout pieces: title, control panel (left), graph (right), footer ---

_title_row = dbc.Row([
    dbc.Col([
        html.H1("Visualisation Sankey - Séquences d'Espèces",
                className="text-center mb-4",
                style={"color": "#2c3e50", "fontWeight": "bold"}),
    ]),
])

# Site picker. Its value drives the button enable/disable callback.
_site_dropdown = dcc.Dropdown(
    id='site-selector',
    options=site_options,
    value=None,
    placeholder="Choisir un site...",
    style={"marginBottom": "15px"},
)

# Maximum number of species per sequence (1-15, labelled every other step).
_max_species_slider = dcc.Slider(
    id='max-species-slider',
    min=1,
    max=15,
    step=1,
    value=5,
    marks={n: f"{n}" for n in range(1, 16, 2)},
    tooltip={"placement": "bottom", "always_visible": True},
)

_update_button_row = dbc.Row([
    dbc.Col([
        dbc.Button("Mettre à jour le diagramme",
                   id="update-button",
                   color="primary",
                   size="sm",
                   className="w-100 mb-2",
                   disabled=True),
    ]),
])

_export_button_row = dbc.Row([
    dbc.Col([
        dbc.Button("Exporter en HTML",
                   id="export-button",
                   color="success",
                   size="sm",
                   className="w-100",
                   disabled=True),
        dcc.Download(id="download-html"),
    ]),
])

_controls_card = dbc.Card([
    dbc.CardBody([
        html.H5("Paramètres", className="card-title"),
        html.Label("Sélectionner un site:", className="fw-bold"),
        _site_dropdown,
        html.Label("Nombre maximum d'espèces:", className="fw-bold"),
        _max_species_slider,
        html.Hr(),
        _update_button_row,
        _export_button_row,
    ]),
])

# Mode-bar config: the camera button exports an SVG at 1000x700, scale 2.
_graph_config = {
    'scrollZoom': True,
    'displayModeBar': True,
    'toImageButtonOptions': {
        'format': 'svg',
        'filename': 'sankey_diagram',
        'height': 700,
        'width': 1000,
        'scale': 2,
    },
}

_graph_card = dbc.Card([
    dbc.CardBody([
        dcc.Graph(id='sankey-graph', style={'height': '700px'}, config=_graph_config),
    ]),
])

_main_row = dbc.Row([
    dbc.Col([_controls_card], width=3),
    dbc.Col([_graph_card], width=9),
])

_footer_row = dbc.Row([
    dbc.Col([
        html.Footer("Application de visualisation Sankey pour les séquences d'espèces",
                    className="text-center mt-4 mb-2 text-muted"),
    ]),
])

app.layout = dbc.Container([_title_row, _main_row, _footer_row], fluid=True)
|
|
|
|
|
|
|
|
|
|
|
@app.callback(
    Output('sankey-graph', 'figure'),
    Input('update-button', 'n_clicks'),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=False
)
def update_sankey(n_clicks, selected_site, max_species):
    """Render the Sankey figure for the chosen site and species limit.

    The callback fires once on page load (no triggering input) and then on
    each click of the update button. It returns:

    * a red error placeholder when the module-level ``data`` failed to load;
    * a prompt placeholder on the initial, untriggered call;
    * otherwise, the ``create_all_sequences_sankey`` figure restricted to
      ``selected_site`` (all sites when it is ``None``).
    """
    def message_figure(text, font_size, font_color):
        # Empty figure with a single centred annotation.
        placeholder = go.Figure()
        placeholder.add_annotation(
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            text=text,
            showarrow=False,
            font=dict(size=font_size, color=font_color),
        )
        return placeholder

    if not data:
        return message_figure("Erreur: Impossible de charger les données", 16, "red")

    if callback_context.triggered_id is None:
        prompt = message_figure("← Sélectionnez un site pour afficher le diagramme", 20, "#2c3e50")
        prompt.update_layout(height=700)
        return prompt

    site_filter = None if selected_site is None else [selected_site]
    return create_all_sequences_sankey(
        data,
        filtered_sites=site_filter,
        filtered_species=None,
        max_species=max_species,
        first_species_colors=True,
    )
|
|
|
|
|
@app.callback(
    Output("download-html", "data"),
    Input("export-button", "n_clicks"),
    State('site-selector', 'value'),
    State('max-species-slider', 'value'),
    prevent_initial_call=True
)
def export_html(n_clicks, selected_site, max_species):
    """Send the current diagram to the browser as a standalone HTML file.

    The figure is rebuilt with the same parameters as the on-screen diagram.
    It is serialised with the plotly.js bundle embedded and an editable
    config, and returned as a ``dcc.Download`` payload (``content`` plus
    ``filename``). Returns ``None`` when there was no click or no data was
    loaded.
    """
    if not data or n_clicks is None:
        return None

    site_filter = None if selected_site is None else [selected_site]
    figure = create_all_sequences_sankey(
        data,
        filtered_sites=site_filter,
        filtered_species=None,
        max_species=max_species,
        first_species_colors=True,
    )

    # Larger SVG snapshot (1100x800) than the in-app graph.
    image_options = {
        'format': 'svg',
        'filename': 'sankey_diagram',
        'height': 800,
        'width': 1100,
        'scale': 2,
    }
    plot_config = {
        'scrollZoom': True,
        'displayModeBar': True,
        'editable': True,
        'toImageButtonOptions': image_options,
    }

    page = figure.to_html(include_plotlyjs=True, full_html=True, config=plot_config)

    return {
        "content": page,
        "filename": f"sankey_site_{selected_site}_{max_species}species.html",
    }
|
|
|
|
|
@app.callback(
    [Output('update-button', 'disabled'),
     Output('export-button', 'disabled')],
    Input('site-selector', 'value')
)
def toggle_button_state(selected_site):
    """Disable the update and export buttons until a site is selected.

    Returns:
        ``(update_disabled, export_disabled)``. Both are True exactly when
        ``selected_site`` is ``None``, for example when the dropdown is cleared.
    """
    disabled = selected_site is None
    return disabled, disabled
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Listen on all interfaces. The port comes from the PORT environment
    # variable, defaulting to 7860.
    port = int(os.environ.get('PORT', 7860))
    # `run_server` is a deprecated alias of `run` (removed in Dash 3.0).
    app.run(debug=False, host='0.0.0.0', port=port)