"""All Gradio callbacks for the Pigeon Pea Pangenome Atlas."""
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from src.state import AppState
from src.gene_card import build_gene_card, render_gene_card_html, export_gene_report
from src.field_report import generate_field_report, export_report_json, export_report_csv
# Color palettes
# Hex colors keyed by gene frequency class ("core", "shell", "cloud"), plus an
# "unknown" entry. The donut chart and the frequency histogram below read the
# three named classes from this mapping.
CORE_COLORS = {"core": "#2E7D32", "shell": "#FFC107", "cloud": "#F44336", "unknown": "#9E9E9E"}
# Plotly Express qualitative "Set3" color sequence.
# NOTE(review): presumably used to color lines by country -- no use is visible
# in this part of the file; confirm against the plotting helpers.
COUNTRY_COLORS = px.colors.qualitative.Set3
# ============================================================
# Quest 0 Callbacks
# ============================================================
def on_line_selected(line_id: str, state: AppState, data: dict) -> tuple:
    """Update the overview metrics when a line is picked in the dropdown.

    Args:
        line_id: Identifier of the chosen line; may be empty when the
            dropdown is cleared.
        state: Session state. A fresh ``AppState`` is created only when this
            is ``None``.
        data: Mapping that provides the ``"line_stats"`` and ``"similarity"``
            DataFrames.

    Returns:
        ``(total_genes, unique_genes, nearest_neighbor, updated_state)``.
        All three text fields are ``"--"`` when ``line_id`` has no row in
        ``line_stats``; ``nearest_neighbor`` alone is ``"--"`` when the line
        has no rows in ``similarity``.
    """
    # Create state only when it is missing. Replacing an existing state on an
    # empty selection would throw away everything already stored on it, and
    # every other callback in this module only guards against ``None``.
    if state is None:
        state = AppState()
    state.selected_line = line_id

    line_stats = data["line_stats"]
    similarity = data["similarity"]

    stats_rows = line_stats[line_stats["line_id"] == line_id]
    if stats_rows.empty:
        return "--", "--", "--", state

    first = stats_rows.iloc[0]
    total_genes = str(int(first["genes_present_count"]))
    unique_genes = str(int(first["unique_genes_count"]))

    # Nearest neighbor: the similarity row with the highest Jaccard score.
    neighbor_rows = similarity[similarity["line_id"] == line_id]
    if neighbor_rows.empty:
        nearest = "--"
    else:
        top = neighbor_rows.nlargest(1, "jaccard_score").iloc[0]
        nearest = f"{top['neighbor_line_id']} ({top['jaccard_score']:.3f})"

    return total_genes, unique_genes, nearest, state
def on_start_journey(state: AppState) -> tuple:
    """Grant the "Explorer" achievement and switch the UI to Quest 1.

    Args:
        state: Session state; a fresh ``AppState`` is used when ``None``.

    Returns:
        ``(tabs_update, state)`` where ``tabs_update`` is a ``gr.Tabs`` built
        with ``selected="quest1"``.
    """
    # NOTE(review): no import that binds `gr` is visible at the top of this
    # module -- confirm gradio is imported somewhere before this runs.
    session = AppState() if state is None else state
    session.award("Explorer")
    tabs_update = gr.Tabs(selected="quest1")
    return tabs_update, session
# ============================================================
# Quest 1 Callbacks
# ============================================================
def build_umap_plot(color_by: str, state: AppState, data: dict) -> go.Figure:
    """Render the 3D UMAP scatter by delegating to ``ui.quest1.build_umap_3d``.

    Args:
        color_by: UI choice; ``"Country"`` colors by country, any other value
            colors by cluster.
        state: Session state; its ``selected_line`` is passed on when the
            state is truthy, otherwise ``None`` is passed.
        data: Mapping that provides ``"embedding"`` and ``"line_stats"``.

    Returns:
        Whatever ``build_umap_3d`` returns for these arguments.
    """
    # Imported at call time rather than at module level (presumably to avoid
    # an import cycle with the UI package -- TODO confirm).
    from ui.quest1 import build_umap_3d

    highlighted = state.selected_line if state else None
    if color_by == "Country":
        color_key = "country"
    else:
        color_key = "cluster"

    embedding = data["embedding"]
    line_stats = data["line_stats"]
    return build_umap_3d(
        embedding,
        line_stats,
        color_by=color_key,
        selected_line=highlighted,
    )
def on_umap_select(selected_data, state: AppState) -> tuple:
    """Store the lines picked on the UMAP plot and summarize them as text.

    Args:
        selected_data: Selection payload; expected to contain a ``"points"``
            list of dicts carrying a ``"hovertext"`` or ``"text"`` label.
        state: Session state; a fresh ``AppState`` is used when ``None``.

    Returns:
        ``(party_text, state)``. ``state.selected_party`` is set to the list
        of non-empty labels, or to ``[]`` when nothing is selected.
    """
    session = state if state is not None else AppState()

    # Guard clause: no payload, or a payload without points.
    if not (selected_data and "points" in selected_data):
        session.selected_party = []
        return "None selected", session

    names = []
    for point in selected_data["points"]:
        label = point.get("hovertext", point.get("text", ""))
        if label:
            names.append(label)
    session.selected_party = names

    # Show at most the first ten names, then a count of the remainder.
    summary = f"Selected {len(names)} lines: " + ", ".join(names[:10])
    overflow = len(names) - 10
    if overflow > 0:
        summary += f" ... +{overflow} more"
    return summary, session
def _party_message_figure(message: str) -> go.Figure:
    """Return an empty figure that shows ``message`` as a single annotation."""
    fig = go.Figure()
    fig.add_annotation(text=message, showarrow=False)
    return fig


def on_compare_party(state: AppState, data: dict) -> tuple:
    """Compare the genes of the selected line against the selected party.

    Args:
        state: Session state providing ``selected_line`` and
            ``selected_party``.
        data: Mapping that may provide a ``"pav"`` DataFrame whose columns
            are line ids and whose cells equal 1 where a gene is present.

    Returns:
        ``(figure, True)``. The figure is a grouped bar chart of shared /
        only-mine / only-party gene counts, or an annotated empty figure when
        the comparison cannot be made. The second element is always ``True``
        (NOTE(review): presumably a visibility flag for the plot component --
        confirm against the UI wiring).
    """
    if not state or not state.selected_line or not state.selected_party:
        return _party_message_figure("Select your line and a party first"), True

    pav = data.get("pav")
    if pav is None:
        return _party_message_figure("PAV data not loaded"), True

    my_line = state.selected_line
    # Indexing a missing column raises KeyError; report it in the figure
    # instead, the same way the other unavailable-data cases are reported.
    if my_line not in pav.columns:
        return _party_message_figure("Selected line not found in PAV data"), True

    party_cols = [c for c in state.selected_party if c in pav.columns and c != my_line]
    if not party_cols:
        return _party_message_figure("No valid party members"), True

    my_genes = set(pav.index[pav[my_line] == 1])
    # A gene belongs to the party when any party member carries it.
    party_genes = set(pav.index[(pav[party_cols] == 1).any(axis=1)])

    shared = len(my_genes & party_genes)
    only_mine = len(my_genes - party_genes)
    only_party = len(party_genes - my_genes)

    fig = go.Figure(data=[
        go.Bar(name="Shared", x=["Gene Sets"], y=[shared], marker_color="#2E7D32"),
        go.Bar(name=f"Only {my_line}", x=["Gene Sets"], y=[only_mine], marker_color="#1565C0"),
        go.Bar(name="Only Party", x=["Gene Sets"], y=[only_party], marker_color="#FFC107"),
    ])
    fig.update_layout(
        barmode="group",
        title=f"Gene Comparison: {my_line} vs {len(party_cols)} party members",
        yaxis_title="Number of genes",
    )
    return fig, True
# ============================================================
# Quest 2 Callbacks
# ============================================================
def build_donut_chart(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build the core/shell/cloud donut chart for the current thresholds.

    A gene is "core" when ``freq_pct >= core_thresh``, otherwise "cloud" when
    ``freq_pct < cloud_thresh``, otherwise "shell".

    Args:
        core_thresh: Minimum frequency (percent) for the core class.
        cloud_thresh: Frequency (percent) below which a gene is cloud.
        data: Mapping that provides the ``"gene_freq"`` DataFrame with a
            ``"freq_pct"`` column.

    Returns:
        A donut (pie with a hole) figure with one slice per class.
    """
    freq_pct = data["gene_freq"]["freq_pct"]

    is_core = freq_pct >= core_thresh
    # Core takes precedence over cloud, as in build_treasure_table. Without
    # this, overlapping thresholds (cloud_thresh > core_thresh) count some
    # genes twice and drive the shell count negative.
    is_cloud = (freq_pct < cloud_thresh) & ~is_core

    core = int(is_core.sum())
    cloud = int(is_cloud.sum())
    shell = len(freq_pct) - core - cloud

    fig = go.Figure(data=[go.Pie(
        labels=["Core", "Shell", "Cloud"],
        values=[core, shell, cloud],
        hole=0.5,
        marker_colors=[CORE_COLORS["core"], CORE_COLORS["shell"], CORE_COLORS["cloud"]],
        textinfo="label+value+percent",
    )])
    fig.update_layout(
        title=f"Gene Classification (Core>={core_thresh}%, Cloud<{cloud_thresh}%)",
        showlegend=True,
    )
    return fig
def build_frequency_histogram(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build overlaid per-class histograms of gene frequency.

    Args:
        core_thresh: Position of the dashed green "core" threshold line.
        cloud_thresh: Position of the dashed red "cloud" threshold line.
        data: Mapping that provides the ``"gene_freq"`` DataFrame with
            ``"freq_pct"`` and ``"core_class"`` columns.

    Returns:
        A figure with one histogram trace per class in ``CORE_COLORS`` except
        "unknown", plus the two threshold lines.
    """
    gene_freq = data["gene_freq"]

    # NOTE(review): traces are grouped by the stored "core_class" column; the
    # thresholds passed in only position the vertical lines. Confirm that is
    # intended, since the table view reclassifies by threshold.
    traces = [
        go.Histogram(
            x=gene_freq.loc[gene_freq["core_class"] == label, "freq_pct"],
            name=label.capitalize(),
            marker_color=shade,
            opacity=0.75,
            nbinsx=50,
        )
        for label, shade in CORE_COLORS.items()
        if label != "unknown"
    ]
    fig = go.Figure(data=traces)

    fig.update_layout(
        barmode="overlay",
        title="Gene Frequency Distribution",
        xaxis_title="Frequency (%)",
        yaxis_title="Count",
    )

    # Threshold markers.
    fig.add_vline(
        x=core_thresh,
        line_dash="dash",
        line_color="green",
        annotation_text=f"Core>={core_thresh}%",
    )
    fig.add_vline(
        x=cloud_thresh,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Cloud<{cloud_thresh}%",
    )
    return fig
def build_treasure_table(state: AppState, core_thresh: float, cloud_thresh: float,
                         filter_type: str, data: dict) -> pd.DataFrame:
    """Build the gene table for the current thresholds and filter.

    Args:
        state: Session state; ``selected_line`` drives the ``in_my_line``
            column and the "Unique to my line" filter.
        core_thresh: Minimum frequency (percent) for the "core" class.
        cloud_thresh: Frequency (percent) below which a gene is "cloud".
        filter_type: One of "Unique to my line", "Rare (<5 lines)" or
            "Cluster markers"; any other value applies no filter.
        data: Mapping with a ``"gene_freq"`` DataFrame and optional ``"pav"``
            and ``"markers"`` DataFrames.

    Returns:
        At most 500 rows sorted by ascending ``freq_count`` with the columns
        ``gene_id, freq_count, freq_pct, core_class, in_my_line``. The input
        frame in ``data`` is not modified (the function works on a copy).
    """
    gene_freq = data["gene_freq"].copy()

    # Reclassify from the current thresholds; core wins over cloud, and
    # anything matching neither condition is shell.
    freq_pct = gene_freq["freq_pct"]
    gene_freq["core_class"] = np.select(
        [(freq_pct >= core_thresh).to_numpy(), (freq_pct < cloud_thresh).to_numpy()],
        ["core", "cloud"],
        default="shell",
    )

    pav = data.get("pav")
    selected = state.selected_line if state else None
    # One guard for every use of pav[selected]: the line must be a column.
    has_line = pav is not None and bool(selected) and selected in pav.columns

    if has_line:
        in_line = pav[selected] == 1
        gene_freq["in_my_line"] = np.where(
            gene_freq["gene_id"].isin(pav.index[in_line]), "Yes", "No"
        )
    else:
        gene_freq["in_my_line"] = "N/A"

    if filter_type == "Unique to my line":
        # Skipped (table left unfiltered) when the line is not available;
        # indexing a missing column here would raise KeyError.
        if has_line:
            unique_mask = (pav.sum(axis=1) == 1) & in_line
            gene_freq = gene_freq[gene_freq["gene_id"].isin(pav.index[unique_mask])]
    elif filter_type == "Rare (<5 lines)":
        # NOTE(review): the label says "<5" but the cutoff includes 5 --
        # confirm which is intended before changing either.
        gene_freq = gene_freq[gene_freq["freq_count"] <= 5]
    elif filter_type == "Cluster markers":
        markers = data.get("markers")
        if markers is not None:
            gene_freq = gene_freq[gene_freq["gene_id"].isin(set(markers["gene_id"]))]

    # Rarest genes first, capped at 500 rows.
    gene_freq = gene_freq.sort_values("freq_count", ascending=True).head(500)
    return gene_freq[["gene_id", "freq_count", "freq_pct", "core_class", "in_my_line"]]
def on_pin_gene(gene_id: str, state: AppState) -> tuple:
    """Pin a gene into the session backpack and report the backpack contents.

    Args:
        gene_id: Gene to pin; an empty value or the table placeholder text
            means nothing is selected.
        state: Session state; a fresh ``AppState`` is used when ``None``.

    Returns:
        ``(backpack_text, state)``. The text lists the backpack genes (or
        "Empty"), prefixed with a notice when ``add_to_backpack`` reports the
        gene was not added.
    """
    session = AppState() if state is None else state

    nothing_chosen = (not gene_id) or gene_id == "Click a row to select"
    if nothing_chosen:
        return "Select a gene first", session

    was_added = session.add_to_backpack(gene_id)

    if session.backpack_genes:
        listing = ", ".join(session.backpack_genes)
    else:
        listing = "Empty"

    if was_added:
        return listing, session
    return f"(already in backpack) {listing}", session
def on_gene_click_table(evt, state: AppState) -> tuple:
    """Record the gene chosen by a table selection event.

    Args:
        evt: Selection event; its ``value`` attribute, when present, is
            converted to ``str`` and used as the gene id.
        state: Session state; a fresh ``AppState`` is used when ``None``.

    Returns:
        ``(gene_id, state)`` with ``state.selected_gene`` updated, or the
        placeholder text and an untouched state when the event has no value.
    """
    session = state if state is not None else AppState()

    # Guard clause: no event, or an event that carries no value.
    if evt is None or not hasattr(evt, "value"):
        return "Click a row to select", session

    chosen = str(evt.value)
    session.selected_gene = chosen
    return chosen, session
# ============================================================
# Quest 3 Callbacks
# ============================================================
def build_hotspot_heatmap(data: dict, top_n_contigs: int = 20) -> go.Figure:
"""Build contig x bin heatmap from hotspot_bins."""
hotspots = data["hotspots"]
# Top N contigs by total genes
contig_counts = hotspots.groupby("contig_id")["total_genes"].sum()
top_contigs = contig_counts.nlargest(top_n_contigs).index.tolist()
subset = hotspots[hotspots["contig_id"].isin(top_contigs)]
if len(subset) == 0:
fig = go.Figure()
fig.add_annotation(text="No hotspot data available", showarrow=False)
return fig
pivot = subset.pivot_table(
index="contig_id", columns="bin_start",
values="variability_score", aggfunc="max"
).fillna(0)
# Shorten contig names for display
short_names = [c.split("|")[-1] if "|" in c else c[:30] for c in pivot.index]
fig = go.Figure(data=go.Heatmap(
z=pivot.values,
x=[f"{int(c/1000)}kb" for c in pivot.columns],
y=short_names,
colorscale=[[0, "#E8F5E9"], [0.5, "#FFC107"], [1.0, "#F44336"]],
colorbar_title="Variability",
hovertemplate="Contig: %{y}
Bin: %{x}
Score: %{z:.1f}
Position: %{x:,.0f}
Select a gene
" protein = data["protein"] row = protein[protein["gene_id"] == gene_id] if len(row) == 0: return "No protein data available for this gene.
" r = row.iloc[0] return ( f"Protein Length: {int(r['protein_length'])} aa
" f"Top Amino Acids: {r['composition_summary']}
" f"