"""All Gradio callbacks for the Pigeon Pea Pangenome Atlas.""" import plotly.graph_objects as go import plotly.express as px import pandas as pd import numpy as np from src.state import AppState from src.gene_card import build_gene_card, render_gene_card_html, export_gene_report from src.field_report import generate_field_report, export_report_json, export_report_csv # Color palettes CORE_COLORS = {"core": "#2E7D32", "shell": "#FFC107", "cloud": "#F44336", "unknown": "#9E9E9E"} COUNTRY_COLORS = px.colors.qualitative.Set3 # ============================================================ # Quest 0 Callbacks # ============================================================ def on_line_selected(line_id: str, state: AppState, data: dict) -> tuple: """ Triggered by dropdown change. Returns: (total_genes, unique_genes, nearest_neighbor, updated_state) """ if not line_id or state is None: state = AppState() state.selected_line = line_id line_stats = data["line_stats"] similarity = data["similarity"] row = line_stats[line_stats["line_id"] == line_id] if len(row) == 0: return "--", "--", "--", state total_genes = str(int(row.iloc[0]["genes_present_count"])) unique_genes = str(int(row.iloc[0]["unique_genes_count"])) # Nearest neighbor sim_rows = similarity[similarity["line_id"] == line_id] if len(sim_rows) > 0: top = sim_rows.nlargest(1, "jaccard_score").iloc[0] nearest = f"{top['neighbor_line_id']} ({top['jaccard_score']:.3f})" else: nearest = "--" return total_genes, unique_genes, nearest, state def on_start_journey(state: AppState) -> tuple: """Award Explorer achievement and switch to Quest 1.""" if state is None: state = AppState() state.award("Explorer") return gr.Tabs(selected="quest1"), state # ============================================================ # Quest 1 Callbacks # ============================================================ def build_umap_plot(color_by: str, state: AppState, data: dict) -> go.Figure: """Build 3D UMAP scatter (delegates to quest1.build_umap_3d).""" from ui.quest1 import build_umap_3d selected_line = state.selected_line if state else None color_key = "country" if color_by == "Country" else "cluster" return build_umap_3d( data["embedding"], data["line_stats"], color_by=color_key, selected_line=selected_line, ) def on_umap_select(selected_data, state: AppState) -> tuple: """Handle UMAP point selection.""" if state is None: state = AppState() if selected_data and "points" in selected_data: selected_lines = [p.get("hovertext", p.get("text", "")) for p in selected_data["points"]] selected_lines = [l for l in selected_lines if l] state.selected_party = selected_lines party_text = f"Selected {len(selected_lines)} lines: " + ", ".join(selected_lines[:10]) if len(selected_lines) > 10: party_text += f" ... +{len(selected_lines) - 10} more" else: state.selected_party = [] party_text = "None selected" return party_text, state def on_compare_party(state: AppState, data: dict) -> tuple: """Compare selected line vs party.""" if not state or not state.selected_line or not state.selected_party: fig = go.Figure() fig.add_annotation(text="Select your line and a party first", showarrow=False) return fig, True pav = data.get("pav") if pav is None: fig = go.Figure() fig.add_annotation(text="PAV data not loaded", showarrow=False) return fig, True my_genes = set(pav.index[pav[state.selected_line] == 1]) party_cols = [c for c in state.selected_party if c in pav.columns and c != state.selected_line] if not party_cols: fig = go.Figure() fig.add_annotation(text="No valid party members", showarrow=False) return fig, True party_genes = set() for col in party_cols: party_genes |= set(pav.index[pav[col] == 1]) shared = len(my_genes & party_genes) only_mine = len(my_genes - party_genes) only_party = len(party_genes - my_genes) fig = go.Figure(data=[ go.Bar(name="Shared", x=["Gene Sets"], y=[shared], marker_color="#2E7D32"), go.Bar(name=f"Only {state.selected_line}", x=["Gene Sets"], y=[only_mine], marker_color="#1565C0"), go.Bar(name="Only Party", x=["Gene Sets"], y=[only_party], marker_color="#FFC107"), ]) fig.update_layout( barmode="group", title=f"Gene Comparison: {state.selected_line} vs {len(party_cols)} party members", yaxis_title="Number of genes", ) return fig, True # ============================================================ # Quest 2 Callbacks # ============================================================ def build_donut_chart(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure: """Build core/shell/cloud donut chart.""" gene_freq = data["gene_freq"] core = int((gene_freq["freq_pct"] >= core_thresh).sum()) cloud = int((gene_freq["freq_pct"] < cloud_thresh).sum()) shell = len(gene_freq) - core - cloud fig = go.Figure(data=[go.Pie( labels=["Core", "Shell", "Cloud"], values=[core, shell, cloud], hole=0.5, marker_colors=[CORE_COLORS["core"], CORE_COLORS["shell"], CORE_COLORS["cloud"]], textinfo="label+value+percent", )]) fig.update_layout( title=f"Gene Classification (Core>={core_thresh}%, Cloud<{cloud_thresh}%)", showlegend=True, ) return fig def build_frequency_histogram(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure: """Build colored histogram of gene frequencies.""" gene_freq = data["gene_freq"] fig = go.Figure() for cls, color in CORE_COLORS.items(): if cls == "unknown": continue subset = gene_freq[gene_freq["core_class"] == cls] fig.add_trace(go.Histogram( x=subset["freq_pct"], name=cls.capitalize(), marker_color=color, opacity=0.75, nbinsx=50, )) fig.update_layout( barmode="overlay", title="Gene Frequency Distribution", xaxis_title="Frequency (%)", yaxis_title="Count", ) # Add threshold lines fig.add_vline(x=core_thresh, line_dash="dash", line_color="green", annotation_text=f"Core>={core_thresh}%") fig.add_vline(x=cloud_thresh, line_dash="dash", line_color="red", annotation_text=f"Cloud<{cloud_thresh}%") return fig def build_treasure_table(state: AppState, core_thresh: float, cloud_thresh: float, filter_type: str, data: dict) -> pd.DataFrame: """Build gene treasure table with current filters.""" gene_freq = data["gene_freq"].copy() # Reclassify based on current thresholds gene_freq["core_class"] = gene_freq["freq_pct"].apply( lambda x: "core" if x >= core_thresh else ("cloud" if x < cloud_thresh else "shell") ) # Add in_my_line column pav = data.get("pav") if pav is not None and state and state.selected_line and state.selected_line in pav.columns: my_presence = pav[state.selected_line] gene_freq["in_my_line"] = gene_freq["gene_id"].map( lambda g: "Yes" if g in my_presence.index and my_presence.get(g, 0) == 1 else "No" ) else: gene_freq["in_my_line"] = "N/A" # Filter if filter_type == "Unique to my line": if pav is not None and state and state.selected_line: unique_mask = (pav.sum(axis=1) == 1) & (pav[state.selected_line] == 1) unique_genes = set(pav.index[unique_mask]) gene_freq = gene_freq[gene_freq["gene_id"].isin(unique_genes)] elif filter_type == "Rare (<5 lines)": gene_freq = gene_freq[gene_freq["freq_count"] <= 5] elif filter_type == "Cluster markers": markers = data.get("markers") if markers is not None: marker_genes = set(markers["gene_id"]) gene_freq = gene_freq[gene_freq["gene_id"].isin(marker_genes)] # Sort and limit gene_freq = gene_freq.sort_values("freq_count", ascending=True).head(500) return gene_freq[["gene_id", "freq_count", "freq_pct", "core_class", "in_my_line"]] def on_pin_gene(gene_id: str, state: AppState) -> tuple: """Add gene to backpack.""" if state is None: state = AppState() if not gene_id or gene_id == "Click a row to select": return "Select a gene first", state added = state.add_to_backpack(gene_id) backpack_text = ", ".join(state.backpack_genes) if state.backpack_genes else "Empty" if not added: backpack_text = f"(already in backpack) {backpack_text}" return backpack_text, state def on_gene_click_table(evt, state: AppState) -> tuple: """Handle table row selection.""" if state is None: state = AppState() if evt is not None and hasattr(evt, 'value'): gene_id = str(evt.value) state.selected_gene = gene_id return gene_id, state return "Click a row to select", state # ============================================================ # Quest 3 Callbacks # ============================================================ def build_hotspot_heatmap(data: dict, top_n_contigs: int = 20) -> go.Figure: """Build contig x bin heatmap from hotspot_bins.""" hotspots = data["hotspots"] # Top N contigs by total genes contig_counts = hotspots.groupby("contig_id")["total_genes"].sum() top_contigs = contig_counts.nlargest(top_n_contigs).index.tolist() subset = hotspots[hotspots["contig_id"].isin(top_contigs)] if len(subset) == 0: fig = go.Figure() fig.add_annotation(text="No hotspot data available", showarrow=False) return fig pivot = subset.pivot_table( index="contig_id", columns="bin_start", values="variability_score", aggfunc="max" ).fillna(0) # Shorten contig names for display short_names = [c.split("|")[-1] if "|" in c else c[:30] for c in pivot.index] fig = go.Figure(data=go.Heatmap( z=pivot.values, x=[f"{int(c/1000)}kb" for c in pivot.columns], y=short_names, colorscale=[[0, "#E8F5E9"], [0.5, "#FFC107"], [1.0, "#F44336"]], colorbar_title="Variability", hovertemplate="Contig: %{y}
Bin: %{x}
Score: %{z:.1f}", )) fig.update_layout( title=f"Genomic Variability Heatmap (Top {top_n_contigs} contigs)", xaxis_title="Genomic position", yaxis_title="Contig", height=600, xaxis=dict( tickangle=-45, nticks=20, tickfont=dict(size=10), ), margin=dict(b=80), ) return fig def on_contig_selected(contig_id: str, data: dict, state: AppState) -> tuple: """Build track plot for selected contig.""" if not contig_id: return go.Figure(), pd.DataFrame() gff = data["gff_index"] gene_freq = data["gene_freq"] contig_genes = gff[gff["contig_id"] == contig_id].merge( gene_freq[["gene_id", "core_class", "freq_pct"]], on="gene_id", how="left" ) contig_genes["core_class"] = contig_genes["core_class"].fillna("unknown") if len(contig_genes) == 0: fig = go.Figure() fig.add_annotation(text="No genes on this contig", showarrow=False) return fig, pd.DataFrame() fig = go.Figure() for cls, color in CORE_COLORS.items(): subset = contig_genes[contig_genes["core_class"] == cls] if len(subset) == 0: continue fig.add_trace(go.Scatter( x=(subset["start"] + subset["end"]) / 2, y=[cls] * len(subset), mode="markers", marker=dict( symbol="line-ew", size=12, color=color, line=dict(width=2, color=color), ), name=cls.capitalize(), text=subset["gene_id"], hovertemplate="Gene: %{text}
Position: %{x:,.0f}", )) short_name = contig_id.split("|")[-1] if "|" in contig_id else contig_id[:30] fig.update_layout( title=f"Gene Track: {short_name}", xaxis_title="Genomic position (bp)", yaxis_title="Gene class", showlegend=True, ) table_df = contig_genes[["gene_id", "start", "end", "strand", "core_class", "freq_pct"]].sort_values("start") return fig, table_df # ============================================================ # Quest 4 Callbacks # ============================================================ def get_protein_stats_html(gene_id: str, data: dict) -> str: """Get protein stats as HTML.""" if not gene_id: return "

Select a gene

" protein = data["protein"] row = protein[protein["gene_id"] == gene_id] if len(row) == 0: return "

No protein data available for this gene.

" r = row.iloc[0] return ( f"
" f"

Protein Length: {int(r['protein_length'])} aa

" f"

Top Amino Acids: {r['composition_summary']}

" f"
" ) def build_backpack_comparison(state: AppState, data: dict) -> go.Figure: """Bar chart of protein lengths for backpack genes.""" if not state or len(state.backpack_genes) < 2: fig = go.Figure() fig.add_annotation(text="Pin at least 2 genes to compare", showarrow=False) return fig protein = data["protein"] bp_prot = protein[protein["gene_id"].isin(state.backpack_genes)] fig = go.Figure(data=[go.Bar( x=bp_prot["gene_id"], y=bp_prot["protein_length"], marker_color="#2E7D32", text=bp_prot["protein_length"], textposition="auto", )]) fig.update_layout( title="Backpack Genes: Protein Length Comparison", xaxis_title="Gene", yaxis_title="Protein Length (aa)", ) return fig def build_composition_heatmap(state: AppState, data: dict) -> go.Figure: """Heatmap of amino acid composition for backpack genes.""" if not state or len(state.backpack_genes) < 2: fig = go.Figure() fig.add_annotation(text="Pin at least 2 genes to compare", showarrow=False) return fig # Parse composition from summary strings protein = data["protein"] bp_prot = protein[protein["gene_id"].isin(state.backpack_genes)] aa_data = {} for _, row in bp_prot.iterrows(): gene_id = row["gene_id"] comp = row["composition_summary"] aa_dict = {} for item in comp.split(", "): parts = item.split(":") if len(parts) == 2: aa = parts[0].strip() pct = float(parts[1].replace("%", "")) aa_dict[aa] = pct aa_data[gene_id] = aa_dict if not aa_data: fig = go.Figure() fig.add_annotation(text="No composition data", showarrow=False) return fig df = pd.DataFrame(aa_data).fillna(0).T fig = go.Figure(data=go.Heatmap( z=df.values, x=df.columns.tolist(), y=df.index.tolist(), colorscale="YlGn", colorbar_title="%", )) fig.update_layout( title="Amino Acid Composition Heatmap", xaxis_title="Amino Acid", yaxis_title="Gene", ) return fig # ============================================================ # Gene Card Callbacks # ============================================================ def on_open_gene_card(gene_id: str, state: AppState, data: dict) -> tuple: """Open Gene Card side panel.""" if not gene_id: return "", False, state state.selected_gene = gene_id card = build_gene_card(gene_id, data) html = render_gene_card_html(card) state.award("Gene Hunter") return html, True, state def on_download_gene_report(state: AppState, data: dict) -> str: """Download gene report.""" if state and state.selected_gene: return export_gene_report(state.selected_gene, data) return None # ============================================================ # Final Report Callbacks # ============================================================ def on_generate_report(state: AppState, data: dict) -> tuple: """Generate field report.""" if state is None: state = AppState() state.award("Cartographer") report_md = generate_field_report(state, data) json_path = export_report_json(state, data) csv_path = export_report_csv(state, data) # Achievement HTML badges = " ".join( f'{a}' for a in sorted(state.achievements) ) return ( report_md, gr.File(value=json_path, visible=True), gr.File(value=csv_path, visible=True), badges, state, ) # ============================================================ # Data Health # ============================================================ def build_data_health_html(validation_report: dict) -> str: """Build data health HTML from validation report.""" rows = "" for k, v in validation_report.items(): if isinstance(v, float): v = f"{v:.1f}%" rows += f"{k}{v}" return f"{rows}
" # Need gr import for Tabs update import gradio as gr