PanGenomeWatchAI / src /callbacks.py
Ashkan Taghipour (The University of Western Australia)
Improve Genome Explorer heatmap: angle x-axis labels, limit tick count
1e42f44
"""All Gradio callbacks for the Pigeon Pea Pangenome Atlas."""
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from src.state import AppState
from src.gene_card import build_gene_card, render_gene_card_html, export_gene_report
from src.field_report import generate_field_report, export_report_json, export_report_csv
# Color palettes
# Colour per gene class; "unknown" is used for genes that have no frequency
# record (see the fillna("unknown") in on_contig_selected).
CORE_COLORS = {
    "core": "#2E7D32",
    "shell": "#FFC107",
    "cloud": "#F44336",
    "unknown": "#9E9E9E",
}
# Qualitative Plotly palette; presumably for per-country colouring (not
# referenced by the callbacks below).
COUNTRY_COLORS = px.colors.qualitative.Set3
# ============================================================
# Quest 0 Callbacks
# ============================================================
def on_line_selected(line_id: str, state: AppState, data: dict) -> tuple:
    """Update the Quest 0 summary values when a line is picked in the dropdown.

    Args:
        line_id: Selected line identifier; empty/None when the dropdown is cleared.
        state: Per-session state; a fresh ``AppState`` is created only if None.
        data: Loaded datasets; uses the ``line_stats`` and ``similarity`` frames.

    Returns:
        (total_genes, unique_genes, nearest_neighbor, updated_state). The first
        three are display strings, "--" when no value is available.
    """
    # Only create a new state when there is none: clearing the dropdown must not
    # throw away the session's existing state (backpack, achievements, ...).
    if state is None:
        state = AppState()
    state.selected_line = line_id
    if not line_id:
        return "--", "--", "--", state

    line_stats = data["line_stats"]
    row = line_stats[line_stats["line_id"] == line_id]
    if row.empty:
        return "--", "--", "--", state
    stats = row.iloc[0]
    total_genes = str(int(stats["genes_present_count"]))
    unique_genes = str(int(stats["unique_genes_count"]))

    # Nearest neighbour: the partner row with the highest jaccard_score.
    similarity = data["similarity"]
    sim_rows = similarity[similarity["line_id"] == line_id]
    if sim_rows.empty:
        nearest = "--"
    else:
        top = sim_rows.nlargest(1, "jaccard_score").iloc[0]
        nearest = f"{top['neighbor_line_id']} ({top['jaccard_score']:.3f})"
    return total_genes, unique_genes, nearest, state
def on_start_journey(state: AppState) -> tuple:
    """Grant the "Explorer" achievement and switch the UI to the Quest 1 tab.

    Returns:
        (a ``gr.Tabs`` update selecting "quest1", the updated state).
    """
    session = state if state is not None else AppState()
    session.award("Explorer")
    tab_switch = gr.Tabs(selected="quest1")
    return tab_switch, session
# ============================================================
# Quest 1 Callbacks
# ============================================================
def build_umap_plot(color_by: str, state: AppState, data: dict) -> go.Figure:
    """Render the 3D UMAP scatter by delegating to ``ui.quest1.build_umap_3d``.

    ``color_by`` is the UI label: "Country" colours points by country, any other
    value colours them by cluster. The session's selected line (None when there
    is no state) is forwarded as ``selected_line``.
    """
    # Local import, presumably to avoid an import cycle with the UI package.
    from ui.quest1 import build_umap_3d

    if color_by == "Country":
        color_key = "country"
    else:
        color_key = "cluster"
    highlighted = None if not state else state.selected_line
    return build_umap_3d(
        data["embedding"],
        data["line_stats"],
        color_by=color_key,
        selected_line=highlighted,
    )
def on_umap_select(selected_data, state: AppState) -> tuple:
    """Store the lines picked on the UMAP plot as the comparison "party".

    Args:
        selected_data: Plot selection payload, assumed to be a mapping with a
            "points" list whose items may carry "hovertext" or "text" labels
            (TODO confirm against the UI wiring).
        state: Per-session state; a fresh ``AppState`` is created if None.

    Returns:
        (summary text, updated state). At most 10 line ids are listed; the rest
        are summarised as "... +N more". "None selected" when nothing usable
        was selected.
    """
    if state is None:
        state = AppState()

    points = selected_data["points"] if selected_data and "points" in selected_data else None
    selected_lines = []
    for point in points or []:
        label = point.get("hovertext", point.get("text", ""))
        if label:
            # Labels may be non-strings (e.g. numeric ids); join() needs str.
            selected_lines.append(str(label))

    state.selected_party = selected_lines
    # Also covers selections whose points carry no usable label, which used to
    # produce the misleading text "Selected 0 lines: ".
    if not selected_lines:
        return "None selected", state

    party_text = f"Selected {len(selected_lines)} lines: " + ", ".join(selected_lines[:10])
    if len(selected_lines) > 10:
        party_text += f" ... +{len(selected_lines) - 10} more"
    return party_text, state
def on_compare_party(state: AppState, data: dict) -> tuple:
    """Compare the selected line's genes with the union of the party's genes.

    Args:
        state: Session state; needs ``selected_line`` and a non-empty
            ``selected_party``.
        data: Loaded datasets; uses ``pav``, read here as genes on the index and
            lines as columns, with 1 meaning "present".

    Returns:
        (figure, True). The figure is either a grouped bar chart of shared /
        only-mine / only-party gene counts, or an annotated placeholder
        explaining why no comparison was made. The second element is always
        True (presumably it toggles the output's visibility).
    """
    def _message(text):
        # Annotated empty figure used for every "cannot compare" outcome.
        fig = go.Figure()
        fig.add_annotation(text=text, showarrow=False)
        return fig, True

    if not state or not state.selected_line or not state.selected_party:
        return _message("Select your line and a party first")
    pav = data.get("pav")
    if pav is None:
        return _message("PAV data not loaded")

    my_line = state.selected_line
    # A line missing from the PAV matrix would otherwise raise KeyError below.
    if my_line not in pav.columns:
        return _message(f"{my_line} not found in PAV data")
    party_cols = [c for c in state.selected_party if c in pav.columns and c != my_line]
    if not party_cols:
        return _message("No valid party members")

    my_genes = set(pav.index[pav[my_line] == 1])
    # Genes present in at least one party member, computed in a single pass.
    party_genes = set(pav.index[(pav[party_cols] == 1).any(axis=1)])

    shared = len(my_genes & party_genes)
    only_mine = len(my_genes - party_genes)
    only_party = len(party_genes - my_genes)

    fig = go.Figure(data=[
        go.Bar(name="Shared", x=["Gene Sets"], y=[shared], marker_color="#2E7D32"),
        go.Bar(name=f"Only {my_line}", x=["Gene Sets"], y=[only_mine], marker_color="#1565C0"),
        go.Bar(name="Only Party", x=["Gene Sets"], y=[only_party], marker_color="#FFC107"),
    ])
    fig.update_layout(
        barmode="group",
        title=f"Gene Comparison: {my_line} vs {len(party_cols)} party members",
        yaxis_title="Number of genes",
    )
    return fig, True
# ============================================================
# Quest 2 Callbacks
# ============================================================
def build_donut_chart(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build the core/shell/cloud donut chart for the current thresholds.

    A gene is core when ``freq_pct >= core_thresh``, otherwise cloud when
    ``freq_pct < cloud_thresh``, otherwise shell. This is the same precedence
    build_treasure_table uses, so the three counts always partition the genes.

    Args:
        core_thresh: Minimum frequency (%) for a core gene.
        cloud_thresh: Frequencies (%) strictly below this are cloud.
        data: Loaded datasets; uses ``gene_freq["freq_pct"]``.
    """
    freq = data["gene_freq"]["freq_pct"]
    is_core = freq >= core_thresh
    # Core wins when the sliders overlap (cloud_thresh > core_thresh). Otherwise
    # such genes would be counted twice and the shell slice could go negative.
    is_cloud = (freq < cloud_thresh) & ~is_core
    core = int(is_core.sum())
    cloud = int(is_cloud.sum())
    shell = len(freq) - core - cloud

    fig = go.Figure(data=[go.Pie(
        labels=["Core", "Shell", "Cloud"],
        values=[core, shell, cloud],
        hole=0.5,
        marker_colors=[CORE_COLORS["core"], CORE_COLORS["shell"], CORE_COLORS["cloud"]],
        textinfo="label+value+percent",
    )])
    fig.update_layout(
        title=f"Gene Classification (Core>={core_thresh}%, Cloud<{cloud_thresh}%)",
        showlegend=True,
    )
    return fig
def build_frequency_histogram(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build an overlaid histogram of gene frequencies coloured by gene class.

    Genes are classified with the *current* thresholds: core when
    ``freq_pct >= core_thresh``, else cloud when ``freq_pct < cloud_thresh``,
    else shell. This keeps the bar colours in agreement with the dashed
    threshold lines drawn on the same figure.

    Args:
        core_thresh: Minimum frequency (%) for a core gene.
        cloud_thresh: Frequencies (%) strictly below this are cloud.
        data: Loaded datasets; uses ``gene_freq["freq_pct"]``.
    """
    freq = data["gene_freq"]["freq_pct"]
    is_core = freq >= core_thresh
    is_cloud = (freq < cloud_thresh) & ~is_core
    class_masks = {
        "core": is_core,
        "shell": ~(is_core | is_cloud),
        "cloud": is_cloud,
    }

    fig = go.Figure()
    for cls, mask in class_masks.items():
        fig.add_trace(go.Histogram(
            x=freq[mask],
            name=cls.capitalize(),
            marker_color=CORE_COLORS[cls],
            opacity=0.75,
            nbinsx=50,
            # Overlay mode bins each trace independently; a shared bingroup
            # gives all classes the same bin edges so bar widths line up.
            bingroup="gene_freq",
        ))
    fig.update_layout(
        barmode="overlay",
        title="Gene Frequency Distribution",
        xaxis_title="Frequency (%)",
        yaxis_title="Count",
    )
    # Threshold markers for the current slider values.
    fig.add_vline(x=core_thresh, line_dash="dash", line_color="green",
                  annotation_text=f"Core>={core_thresh}%")
    fig.add_vline(x=cloud_thresh, line_dash="dash", line_color="red",
                  annotation_text=f"Cloud<{cloud_thresh}%")
    return fig
def build_treasure_table(state: AppState, core_thresh: float, cloud_thresh: float,
                         filter_type: str, data: dict) -> pd.DataFrame:
    """Build the gene "treasure" table for the current thresholds and filter.

    Args:
        state: Session state; ``selected_line`` drives the ``in_my_line``
            column and the "Unique to my line" filter.
        core_thresh: Genes with ``freq_pct >= core_thresh`` are core.
        cloud_thresh: Otherwise genes with ``freq_pct < cloud_thresh`` are
            cloud; everything else is shell.
        filter_type: "Unique to my line", "Rare (<5 lines)" or
            "Cluster markers"; any other value applies no filter.
        data: Loaded datasets; uses ``gene_freq`` and optionally ``pav``
            (genes x lines, 1 = present) and ``markers``.

    Returns:
        Up to 500 rows, sorted by ascending ``freq_count``, with columns
        gene_id, freq_count, freq_pct, core_class, in_my_line. The
        in_my_line values are "Yes"/"No", or "N/A" without a usable line.
    """
    gene_freq = data["gene_freq"].copy()

    # Reclassify with the current thresholds; core takes precedence over cloud.
    freq_pct = gene_freq["freq_pct"]
    gene_freq["core_class"] = np.select(
        [freq_pct >= core_thresh, freq_pct < cloud_thresh],
        ["core", "cloud"],
        default="shell",
    )

    # The selected line is usable only if the PAV matrix has a column for it.
    pav = data.get("pav")
    line = state.selected_line if state else None
    has_line = pav is not None and bool(line) and line in pav.columns
    if has_line:
        # Genes absent from the PAV index map to NaN and therefore to "No".
        present = gene_freq["gene_id"].map(pav[line]).eq(1)
        gene_freq["in_my_line"] = np.where(present, "Yes", "No")
    else:
        gene_freq["in_my_line"] = "N/A"

    if filter_type == "Unique to my line":
        if has_line:
            # Present in exactly one line, and that line is ours.
            unique_mask = (pav.sum(axis=1) == 1) & (pav[line] == 1)
            unique_genes = set(pav.index[unique_mask])
            gene_freq = gene_freq[gene_freq["gene_id"].isin(unique_genes)]
        else:
            # No usable line means nothing can be unique to it; do not fall
            # through to showing every gene.
            gene_freq = gene_freq.iloc[0:0]
    elif filter_type == "Rare (<5 lines)":
        # Strictly fewer than 5 lines, as the filter label states.
        gene_freq = gene_freq[gene_freq["freq_count"] < 5]
    elif filter_type == "Cluster markers":
        markers = data.get("markers")
        if markers is not None:
            marker_genes = set(markers["gene_id"])
            gene_freq = gene_freq[gene_freq["gene_id"].isin(marker_genes)]

    gene_freq = gene_freq.sort_values("freq_count", ascending=True).head(500)
    return gene_freq[["gene_id", "freq_count", "freq_pct", "core_class", "in_my_line"]]
def on_pin_gene(gene_id: str, state: AppState) -> tuple:
    """Pin ``gene_id`` into the session backpack.

    Returns:
        (display text, updated state). The text is "Select a gene first" when
        there is no real selection. Otherwise it is the comma-separated
        backpack ("Empty" if it has no genes), prefixed with
        "(already in backpack) " when ``add_to_backpack`` reported no addition.
    """
    state = AppState() if state is None else state
    placeholder = "Click a row to select"
    if not gene_id or gene_id == placeholder:
        return "Select a gene first", state

    was_added = state.add_to_backpack(gene_id)
    pinned = state.backpack_genes
    listing = ", ".join(pinned) if pinned else "Empty"
    prefix = "" if was_added else "(already in backpack) "
    return prefix + listing, state
def on_gene_click_table(evt, state: AppState) -> tuple:
    """Record the clicked table cell's value as the selected gene.

    Args:
        evt: Selection event; only its ``value`` attribute is read.
        state: Per-session state; a fresh ``AppState`` is created if None.

    Returns:
        (str(evt.value), state) after storing it in ``state.selected_gene``;
        ("Click a row to select", state) when there is no usable event.
    """
    if state is None:
        state = AppState()
    if evt is None or not hasattr(evt, "value"):
        return "Click a row to select", state
    clicked = str(evt.value)
    state.selected_gene = clicked
    return clicked, state
# ============================================================
# Quest 3 Callbacks
# ============================================================
def build_hotspot_heatmap(data: dict, top_n_contigs: int = 20) -> go.Figure:
    """Build a contig x genomic-bin heatmap of variability scores.

    Args:
        data: Loaded datasets; uses ``hotspots`` with columns contig_id,
            bin_start (labelled in kb, i.e. treated as base pairs),
            total_genes and variability_score.
        top_n_contigs: How many contigs to keep, ranked by summed
            ``total_genes``.

    Returns:
        A heatmap of the max variability score per contig/bin (0 where a bin
        has no row), or an annotated placeholder when there is nothing to plot.
    """
    hotspots = data["hotspots"]
    contig_counts = hotspots.groupby("contig_id")["total_genes"].sum()
    top_contigs = contig_counts.nlargest(top_n_contigs).index.tolist()
    subset = hotspots[hotspots["contig_id"].isin(top_contigs)]
    if subset.empty:
        fig = go.Figure()
        fig.add_annotation(text="No hotspot data available", showarrow=False)
        return fig

    pivot = subset.pivot_table(
        index="contig_id", columns="bin_start",
        values="variability_score", aggfunc="max",
    ).fillna(0)

    # Shorten contig names for display. Identical category labels would share
    # one heatmap row, so any short name used by several contigs falls back to
    # the full (unique) contig id.
    full_ids = [str(c) for c in pivot.index]
    short_ids = [c.split("|")[-1] if "|" in c else c[:30] for c in full_ids]
    clashes = pd.Series(short_ids, dtype=object).duplicated(keep=False).tolist()
    y_labels = [full if clash else short
                for full, short, clash in zip(full_ids, short_ids, clashes)]

    fig = go.Figure(data=go.Heatmap(
        z=pivot.values,
        x=[f"{int(c / 1000)}kb" for c in pivot.columns],
        y=y_labels,
        colorscale=[[0, "#E8F5E9"], [0.5, "#FFC107"], [1.0, "#F44336"]],
        colorbar_title="Variability",
        hovertemplate="Contig: %{y}<br>Bin: %{x}<br>Score: %{z:.1f}<extra></extra>",
    ))
    fig.update_layout(
        # Report how many contigs are actually shown, which can be fewer than
        # top_n_contigs.
        title=f"Genomic Variability Heatmap (Top {len(pivot)} contigs)",
        xaxis_title="Genomic position",
        yaxis_title="Contig",
        height=600,
        xaxis=dict(
            tickangle=-45,
            nticks=20,
            tickfont=dict(size=10),
        ),
        margin=dict(b=80),
    )
    return fig
def on_contig_selected(contig_id: str, data: dict, state: AppState) -> tuple:
    """Build the gene track plot and gene table for one contig.

    Args:
        contig_id: Contig to display; a falsy value yields an empty figure and
            an empty table.
        data: Loaded datasets; uses ``gff_index`` (gene_id, contig_id, start,
            end, strand) and ``gene_freq`` (gene_id, core_class, freq_pct).
        state: Not read here; kept for the callback signature.

    Returns:
        (figure, table). The figure has one marker trace per gene class,
        placed at each gene's midpoint. The table lists the contig's genes
        sorted by start.
    """
    if not contig_id:
        return go.Figure(), pd.DataFrame()

    gff = data["gff_index"]
    annotations = data["gene_freq"][["gene_id", "core_class", "freq_pct"]]
    genes = gff[gff["contig_id"] == contig_id].merge(annotations, on="gene_id", how="left")
    # Genes without a frequency record get the "unknown" class and colour.
    genes["core_class"] = genes["core_class"].fillna("unknown")

    if genes.empty:
        placeholder = go.Figure()
        placeholder.add_annotation(text="No genes on this contig", showarrow=False)
        return placeholder, pd.DataFrame()

    fig = go.Figure()
    for gene_class, colour in CORE_COLORS.items():
        members = genes[genes["core_class"] == gene_class]
        if members.empty:
            continue
        midpoints = (members["start"] + members["end"]) / 2
        fig.add_trace(go.Scatter(
            x=midpoints,
            y=[gene_class] * len(members),
            mode="markers",
            marker=dict(
                symbol="line-ew", size=12, color=colour,
                line=dict(width=2, color=colour),
            ),
            name=gene_class.capitalize(),
            text=members["gene_id"],
            hovertemplate="Gene: %{text}<br>Position: %{x:,.0f}<extra></extra>",
        ))

    label = contig_id.split("|")[-1] if "|" in contig_id else contig_id[:30]
    fig.update_layout(
        title=f"Gene Track: {label}",
        xaxis_title="Genomic position (bp)",
        yaxis_title="Gene class",
        showlegend=True,
    )
    table_columns = ["gene_id", "start", "end", "strand", "core_class", "freq_pct"]
    return fig, genes[table_columns].sort_values("start")
# ============================================================
# Quest 4 Callbacks
# ============================================================
def get_protein_stats_html(gene_id: str, data: dict) -> str:
    """Render a gene's protein length and top amino acids as an HTML card.

    Returns a prompt when no gene is given and a notice when
    ``data["protein"]`` has no row for it. When several rows match, the first
    one is used.
    """
    if not gene_id:
        return "<p>Select a gene</p>"
    table = data["protein"]
    matches = table.loc[table["gene_id"] == gene_id]
    if matches.empty:
        return "<p><i>No protein data available for this gene.</i></p>"
    first = matches.iloc[0]
    length_aa = int(first["protein_length"])
    pieces = [
        "<div class='stat-card'>",
        f"<p><b>Protein Length:</b> {length_aa} aa</p>",
        f"<p><b>Top Amino Acids:</b> {first['composition_summary']}</p>",
        "</div>",
    ]
    return "".join(pieces)
def build_backpack_comparison(state: AppState, data: dict) -> go.Figure:
    """Bar chart comparing the protein lengths of the backpack genes.

    Needs at least two pinned genes; otherwise an annotated placeholder is
    returned. Pinned genes with no row in ``data["protein"]`` get no bar.
    """
    if state and len(state.backpack_genes) >= 2:
        proteins = data["protein"]
        pinned = proteins[proteins["gene_id"].isin(state.backpack_genes)]
        lengths = pinned["protein_length"]
        bars = go.Bar(
            x=pinned["gene_id"],
            y=lengths,
            marker_color="#2E7D32",
            text=lengths,
            textposition="auto",
        )
        chart = go.Figure(data=[bars])
        chart.update_layout(
            title="Backpack Genes: Protein Length Comparison",
            xaxis_title="Gene",
            yaxis_title="Protein Length (aa)",
        )
        return chart

    placeholder = go.Figure()
    placeholder.add_annotation(text="Pin at least 2 genes to compare", showarrow=False)
    return placeholder
def build_composition_heatmap(state: AppState, data: dict) -> go.Figure:
    """Heatmap of amino-acid composition (%) for the genes in the backpack.

    Composition is parsed from each gene's ``composition_summary`` string, in
    the form this parser accepts: comma-separated "AA:pct%" items such as
    "L:9.8%, A:7.5%". Non-string summaries and items that do not parse are
    skipped instead of aborting the whole chart. Amino acids missing for a
    gene are shown as 0.

    Args:
        state: Session state; needs at least two ``backpack_genes``.
        data: Loaded datasets; uses ``protein`` (gene_id, composition_summary).
    """
    def _message(text):
        # Annotated empty figure for the "nothing to show" cases.
        fig = go.Figure()
        fig.add_annotation(text=text, showarrow=False)
        return fig

    if not state or len(state.backpack_genes) < 2:
        return _message("Pin at least 2 genes to compare")

    protein = data["protein"]
    bp_prot = protein[protein["gene_id"].isin(state.backpack_genes)]

    aa_data = {}
    for gene_id, comp in zip(bp_prot["gene_id"], bp_prot["composition_summary"]):
        # Non-strings (e.g. NaN for a missing value) cannot be parsed.
        if not isinstance(comp, str):
            continue
        aa_dict = {}
        # Split on "," and strip, so both "A:1%, B:2%" and "A:1%,B:2%" parse.
        for item in comp.split(","):
            parts = item.split(":")
            if len(parts) != 2:
                continue
            try:
                pct = float(parts[1].replace("%", ""))
            except ValueError:
                continue
            aa_dict[parts[0].strip()] = pct
        aa_data[gene_id] = aa_dict

    if not any(aa_data.values()):
        return _message("No composition data")

    # Rows = genes, columns = amino acids.
    df = pd.DataFrame(aa_data).fillna(0).T
    fig = go.Figure(data=go.Heatmap(
        z=df.values,
        x=df.columns.tolist(),
        y=df.index.tolist(),
        colorscale="YlGn",
        colorbar_title="%",
    ))
    fig.update_layout(
        title="Amino Acid Composition Heatmap",
        xaxis_title="Amino Acid",
        yaxis_title="Gene",
    )
    return fig
# ============================================================
# Gene Card Callbacks
# ============================================================
def on_open_gene_card(gene_id: str, state: AppState, data: dict) -> tuple:
    """Open the Gene Card side panel for ``gene_id``.

    Returns:
        (card HTML, panel visible flag, state). For an empty ``gene_id`` this is
        ("", False, state), with the state passed back unchanged. Otherwise the
        gene becomes ``state.selected_gene``, its card is rendered, and the
        "Gene Hunter" achievement is awarded.
    """
    if not gene_id:
        return "", False, state
    # Same None guard as the other callbacks; previously a missing state
    # raised AttributeError here.
    if state is None:
        state = AppState()
    state.selected_gene = gene_id
    card = build_gene_card(gene_id, data)
    card_html = render_gene_card_html(card)
    state.award("Gene Hunter")
    return card_html, True, state
def on_download_gene_report(state: AppState, data: dict) -> "str | None":
    """Export a report for the currently selected gene.

    Returns whatever ``export_gene_report`` returns for ``state.selected_gene``
    (presumably a file path for the download component; confirm in
    src.gene_card). Returns None when there is no state or no selected gene,
    hence the optional return annotation.
    """
    if not state or not state.selected_gene:
        return None
    return export_gene_report(state.selected_gene, data)
# ============================================================
# Final Report Callbacks
# ============================================================
def on_generate_report(state: AppState, data: dict) -> tuple:
    """Generate the final field report, its JSON/CSV exports and the badge row.

    "Cartographer" is awarded before the report and exports are produced.

    Returns:
        (report markdown, JSON ``gr.File``, CSV ``gr.File``, badge HTML with one
        span per achievement in sorted order, updated state).
    """
    if state is None:
        state = AppState()
    state.award("Cartographer")

    report_md = generate_field_report(state, data)
    json_path = export_report_json(state, data)
    csv_path = export_report_csv(state, data)

    badge_spans = [
        f'<span class="achievement-badge">{name}</span>'
        for name in sorted(state.achievements)
    ]
    badges = " ".join(badge_spans)

    json_file = gr.File(value=json_path, visible=True)
    csv_file = gr.File(value=csv_path, visible=True)
    return report_md, json_file, csv_file, badges, state
# ============================================================
# Data Health
# ============================================================
def build_data_health_html(validation_report: dict) -> str:
    """Render a validation report as a two-column (key, value) HTML table.

    Float values are shown with one decimal place and a trailing "%". Every
    other value is inserted with default f-string formatting.
    """
    def _cell(value):
        return f"{value:.1f}%" if isinstance(value, float) else value

    body = "".join(
        f"<tr><td><b>{key}</b></td><td>{_cell(value)}</td></tr>"
        for key, value in validation_report.items()
    )
    return f"<table style='width:100%'>{body}</table>"
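# gradio (gr) is used by on_start_journey (gr.Tabs) and on_generate_report (gr.File).
# Importing it at the end of the module is fine because those names are only
# resolved when the callbacks run, after the module has finished loading.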
import gradio as gr