Spaces:
Sleeping
Sleeping
Ashkan Taghipour (The University of Western Australia)
Improve Genome Explorer heatmap: angle x-axis labels, limit tick count
1e42f44 | """All Gradio callbacks for the Pigeon Pea Pangenome Atlas.""" | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import pandas as pd | |
| import numpy as np | |
| from src.state import AppState | |
| from src.gene_card import build_gene_card, render_gene_card_html, export_gene_report | |
| from src.field_report import generate_field_report, export_report_json, export_report_csv | |
# Color palettes
# Hex colors keyed by gene class. The callbacks below use them for the donut
# chart, the frequency histogram and the per-contig gene track; "unknown" is
# the fallback class assigned to genes with no frequency record.
CORE_COLORS = {"core": "#2E7D32", "shell": "#FFC107", "cloud": "#F44336", "unknown": "#9E9E9E"}
# Plotly Express qualitative "Set3" palette.
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed by country-colored plots elsewhere -- confirm before removing.
COUNTRY_COLORS = px.colors.qualitative.Set3
| # ============================================================ | |
| # Quest 0 Callbacks | |
| # ============================================================ | |
def on_line_selected(line_id: str, state: AppState, data: dict) -> tuple:
    """Refresh the Quest 0 summary values when the line dropdown changes.

    Args:
        line_id: Identifier picked in the dropdown; may be empty when the
            dropdown is cleared.
        state: Current session state, or ``None`` if none exists yet.
        data: Loaded tables; reads ``data["line_stats"]`` and
            ``data["similarity"]``.

    Returns:
        ``(total_genes, unique_genes, nearest_neighbor, updated_state)``.
        The first three items are display strings and are ``"--"`` when the
        value is unavailable.
    """
    # Create a state only when there is none. An empty selection must not
    # replace an existing state, otherwise achievements and pinned genes
    # already stored on it would be thrown away.
    if state is None:
        state = AppState()
    state.selected_line = line_id

    if not line_id:
        return "--", "--", "--", state

    line_stats = data["line_stats"]
    similarity = data["similarity"]

    row = line_stats[line_stats["line_id"] == line_id]
    if len(row) == 0:
        return "--", "--", "--", state

    stats = row.iloc[0]
    total_genes = str(int(stats["genes_present_count"]))
    unique_genes = str(int(stats["unique_genes_count"]))

    # Nearest neighbor: the row with the highest Jaccard score for this line.
    sim_rows = similarity[similarity["line_id"] == line_id]
    if len(sim_rows) > 0:
        top = sim_rows.nlargest(1, "jaccard_score").iloc[0]
        nearest = f"{top['neighbor_line_id']} ({top['jaccard_score']:.3f})"
    else:
        nearest = "--"

    return total_genes, unique_genes, nearest, state
def on_start_journey(state: AppState) -> tuple:
    """Grant the "Explorer" achievement and jump to the Quest 1 tab.

    Returns:
        ``(tabs_update, state)``: a ``gr.Tabs`` selecting ``"quest1"`` and the
        session state (created here when ``state`` was ``None``).
    """
    session = AppState() if state is None else state
    session.award("Explorer")
    tabs_update = gr.Tabs(selected="quest1")
    return tabs_update, session
| # ============================================================ | |
| # Quest 1 Callbacks | |
| # ============================================================ | |
def build_umap_plot(color_by: str, state: AppState, data: dict) -> go.Figure:
    """Return the 3D UMAP scatter built by ``ui.quest1.build_umap_3d``.

    Args:
        color_by: UI choice; ``"Country"`` colors by country, anything else
            colors by cluster.
        state: Session state; its ``selected_line`` is passed through when
            the state is truthy.
        data: Loaded tables; reads ``data["embedding"]`` and
            ``data["line_stats"]``.
    """
    # Imported at call time, as in the original (reason not visible here).
    from ui.quest1 import build_umap_3d

    if color_by == "Country":
        color_key = "country"
    else:
        color_key = "cluster"

    highlighted_line = state.selected_line if state else None

    return build_umap_3d(
        data["embedding"],
        data["line_stats"],
        color_by=color_key,
        selected_line=highlighted_line,
    )
def on_umap_select(selected_data, state: AppState) -> tuple:
    """Store the lines picked on the UMAP plot as the current "party".

    Args:
        selected_data: Selection payload; when it contains a ``"points"``
            list, each point's ``"hovertext"`` (falling back to ``"text"``)
            is taken as a line identifier.
        state: Session state, or ``None`` to start a new one.

    Returns:
        ``(party_text, state)`` where ``party_text`` summarizes the selection
        and lists at most the first ten identifiers.
    """
    if state is None:
        state = AppState()

    has_points = bool(selected_data) and "points" in selected_data
    if not has_points:
        state.selected_party = []
        return "None selected", state

    labels = []
    for point in selected_data["points"]:
        label = point.get("hovertext", point.get("text", ""))
        if label:  # drop points that carry no usable label
            labels.append(label)
    state.selected_party = labels

    party_text = f"Selected {len(labels)} lines: " + ", ".join(labels[:10])
    overflow = len(labels) - 10
    if overflow > 0:
        party_text += f" ... +{overflow} more"
    return party_text, state
def _party_message_figure(message: str) -> go.Figure:
    """Return an otherwise empty figure carrying ``message`` as an annotation."""
    fig = go.Figure()
    fig.add_annotation(text=message, showarrow=False)
    return fig


def on_compare_party(state: AppState, data: dict) -> tuple:
    """Compare the genes of the selected line with those of the party.

    A gene counts as present in a line when its PAV cell equals 1. The party
    gene set is the union over all party members (excluding the selected
    line itself and members without a PAV column).

    Args:
        state: Session state; reads ``selected_line`` and ``selected_party``.
        data: Loaded tables; reads the optional ``"pav"`` frame (genes as
            index, lines as columns).

    Returns:
        ``(figure, True)``. The figure is a grouped bar chart of shared /
        only-mine / only-party gene counts, or a placeholder figure with an
        explanatory message when the comparison cannot be made. The second
        item is always ``True`` (presumably a visibility flag for the output
        component -- confirm against the UI wiring).
    """
    if not state or not state.selected_line or not state.selected_party:
        return _party_message_figure("Select your line and a party first"), True

    pav = data.get("pav")
    if pav is None:
        return _party_message_figure("PAV data not loaded"), True

    # A selected line without a PAV column would raise KeyError below;
    # report it instead of crashing the callback.
    if state.selected_line not in pav.columns:
        return _party_message_figure("Selected line not found in PAV data"), True

    party_cols = [
        col for col in state.selected_party
        if col in pav.columns and col != state.selected_line
    ]
    if not party_cols:
        return _party_message_figure("No valid party members"), True

    my_genes = set(pav.index[pav[state.selected_line] == 1])
    # Union of genes present in at least one party member, in one pass.
    in_any_member = (pav[party_cols] == 1).any(axis=1)
    party_genes = set(pav.index[in_any_member])

    shared = len(my_genes & party_genes)
    only_mine = len(my_genes - party_genes)
    only_party = len(party_genes - my_genes)

    fig = go.Figure(data=[
        go.Bar(name="Shared", x=["Gene Sets"], y=[shared], marker_color="#2E7D32"),
        go.Bar(name=f"Only {state.selected_line}", x=["Gene Sets"], y=[only_mine],
               marker_color="#1565C0"),
        go.Bar(name="Only Party", x=["Gene Sets"], y=[only_party], marker_color="#FFC107"),
    ])
    fig.update_layout(
        barmode="group",
        title=f"Gene Comparison: {state.selected_line} vs {len(party_cols)} party members",
        yaxis_title="Number of genes",
    )
    return fig, True
| # ============================================================ | |
| # Quest 2 Callbacks | |
| # ============================================================ | |
def build_donut_chart(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build the core/shell/cloud donut chart for the given thresholds.

    Genes with ``freq_pct >= core_thresh`` count as core, genes with
    ``freq_pct < cloud_thresh`` count as cloud, and shell is the remainder
    of ``data["gene_freq"]``.
    """
    freq_pct = data["gene_freq"]["freq_pct"]

    n_core = int((freq_pct >= core_thresh).sum())
    n_cloud = int((freq_pct < cloud_thresh).sum())
    n_shell = len(freq_pct) - n_core - n_cloud

    slice_colors = [CORE_COLORS[name] for name in ("core", "shell", "cloud")]
    donut = go.Pie(
        labels=["Core", "Shell", "Cloud"],
        values=[n_core, n_shell, n_cloud],
        hole=0.5,
        marker_colors=slice_colors,
        textinfo="label+value+percent",
    )

    fig = go.Figure(data=[donut])
    fig.update_layout(
        title=f"Gene Classification (Core>={core_thresh}%, Cloud<{cloud_thresh}%)",
        showlegend=True,
    )
    return fig
def build_frequency_histogram(core_thresh: float, cloud_thresh: float, data: dict) -> go.Figure:
    """Build overlaid per-class histograms of gene frequency.

    Traces are split by the ``core_class`` column already stored in
    ``data["gene_freq"]`` (the "unknown" class is skipped); the two
    thresholds are only drawn as dashed vertical marker lines.
    """
    gene_freq = data["gene_freq"]
    fig = go.Figure()

    plotted_classes = [
        (name, color) for name, color in CORE_COLORS.items() if name != "unknown"
    ]
    for class_name, class_color in plotted_classes:
        class_freqs = gene_freq.loc[gene_freq["core_class"] == class_name, "freq_pct"]
        fig.add_trace(go.Histogram(
            x=class_freqs,
            name=class_name.capitalize(),
            marker_color=class_color,
            opacity=0.75,
            nbinsx=50,
        ))

    fig.update_layout(
        barmode="overlay",
        title="Gene Frequency Distribution",
        xaxis_title="Frequency (%)",
        yaxis_title="Count",
    )

    # Threshold markers: core first, then cloud.
    threshold_lines = (
        (core_thresh, "green", f"Core>={core_thresh}%"),
        (cloud_thresh, "red", f"Cloud<{cloud_thresh}%"),
    )
    for position, line_color, label in threshold_lines:
        fig.add_vline(x=position, line_dash="dash", line_color=line_color,
                      annotation_text=label)
    return fig
def build_treasure_table(state: AppState, core_thresh: float, cloud_thresh: float,
                         filter_type: str, data: dict) -> pd.DataFrame:
    """Build the gene "treasure" table for the current thresholds and filter.

    Args:
        state: Session state; ``selected_line`` (if any) drives the
            ``in_my_line`` column and the "Unique to my line" filter.
        core_thresh: Genes with ``freq_pct >= core_thresh`` are "core".
        cloud_thresh: Remaining genes with ``freq_pct < cloud_thresh`` are
            "cloud"; everything else is "shell".
        filter_type: One of "Unique to my line", "Rare (<5 lines)",
            "Cluster markers"; any other value applies no filter.
        data: Loaded tables; reads ``"gene_freq"`` and the optional ``"pav"``
            and ``"markers"`` frames.

    Returns:
        Up to 500 rows, rarest genes first, with columns ``gene_id``,
        ``freq_count``, ``freq_pct``, ``core_class`` and ``in_my_line``.
        ``data["gene_freq"]`` itself is not modified.
    """
    gene_freq = data["gene_freq"].copy()

    # Reclassify based on current thresholds ("core" wins over "cloud" when
    # both conditions hold, matching the order of the conditions).
    freq_pct = gene_freq["freq_pct"]
    gene_freq["core_class"] = np.select(
        [freq_pct >= core_thresh, freq_pct < cloud_thresh],
        ["core", "cloud"],
        default="shell",
    )

    pav = data.get("pav")
    selected_line = state.selected_line if state else None
    # The selected line is usable only if PAV is loaded AND has that column.
    has_line = pav is not None and bool(selected_line) and selected_line in pav.columns

    # Add in_my_line column
    if has_line:
        presence = pav[selected_line]
        present_genes = presence.index[presence == 1]
        gene_freq["in_my_line"] = (
            gene_freq["gene_id"].isin(present_genes).map({True: "Yes", False: "No"})
        )
    else:
        gene_freq["in_my_line"] = "N/A"

    # Filter
    if filter_type == "Unique to my line":
        # Without a usable selected line the table is left unfiltered rather
        # than raising KeyError on a missing PAV column.
        if has_line:
            unique_mask = (pav.sum(axis=1) == 1) & (presence == 1)
            gene_freq = gene_freq[gene_freq["gene_id"].isin(pav.index[unique_mask])]
    elif filter_type == "Rare (<5 lines)":
        # Strictly fewer than 5 lines, as the filter label states.
        gene_freq = gene_freq[gene_freq["freq_count"] < 5]
    elif filter_type == "Cluster markers":
        markers = data.get("markers")
        if markers is not None:
            gene_freq = gene_freq[gene_freq["gene_id"].isin(set(markers["gene_id"]))]

    # Sort and limit
    gene_freq = gene_freq.sort_values("freq_count", ascending=True).head(500)
    return gene_freq[["gene_id", "freq_count", "freq_pct", "core_class", "in_my_line"]]
def on_pin_gene(gene_id: str, state: AppState) -> tuple:
    """Pin ``gene_id`` to the session backpack.

    Returns:
        ``(backpack_text, state)``. ``backpack_text`` is a prompt when no
        gene is selected, otherwise the comma-separated backpack contents
        (prefixed with a note when the gene was already pinned).
    """
    if state is None:
        state = AppState()

    # The placeholder label of the selection box counts as "nothing selected".
    nothing_selected = (not gene_id) or gene_id == "Click a row to select"
    if nothing_selected:
        return "Select a gene first", state

    newly_added = state.add_to_backpack(gene_id)

    if state.backpack_genes:
        contents = ", ".join(state.backpack_genes)
    else:
        contents = "Empty"

    if newly_added:
        return contents, state
    return f"(already in backpack) {contents}", state
def on_gene_click_table(evt, state: AppState) -> tuple:
    """Record the gene picked in the table as the selected gene.

    Args:
        evt: Selection event; its ``value`` attribute (stringified) is used
            as the gene identifier.
        state: Session state, or ``None`` to start a new one.

    Returns:
        ``(gene_id, state)``, or the placeholder text and the unchanged state
        when ``evt`` is ``None`` or has no ``value`` attribute.
    """
    if state is None:
        state = AppState()

    if evt is None or not hasattr(evt, "value"):
        return "Click a row to select", state

    picked = str(evt.value)
    state.selected_gene = picked
    return picked, state
| # ============================================================ | |
| # Quest 3 Callbacks | |
| # ============================================================ | |
def build_hotspot_heatmap(data: dict, top_n_contigs: int = 20) -> go.Figure:
    """Build the contig x bin variability heatmap from ``data["hotspots"]``.

    Only the ``top_n_contigs`` contigs with the largest summed
    ``total_genes`` are shown. Each cell is the maximum
    ``variability_score`` for that contig and ``bin_start``; missing cells
    are filled with 0.
    """
    hotspots = data["hotspots"]

    # Top N contigs by total genes
    genes_per_contig = hotspots.groupby("contig_id")["total_genes"].sum()
    kept_contigs = genes_per_contig.nlargest(top_n_contigs).index.tolist()
    subset = hotspots[hotspots["contig_id"].isin(kept_contigs)]

    if not len(subset):
        empty_fig = go.Figure()
        empty_fig.add_annotation(text="No hotspot data available", showarrow=False)
        return empty_fig

    grid = subset.pivot_table(
        index="contig_id",
        columns="bin_start",
        values="variability_score",
        aggfunc="max",
    ).fillna(0)

    # Shorten contig names for display: keep the part after the last "|",
    # or the first 30 characters when there is no "|".
    row_labels = []
    for contig in grid.index:
        if "|" in contig:
            row_labels.append(contig.split("|")[-1])
        else:
            row_labels.append(contig[:30])
    column_labels = [f"{int(bin_start/1000)}kb" for bin_start in grid.columns]

    heatmap = go.Heatmap(
        z=grid.values,
        x=column_labels,
        y=row_labels,
        colorscale=[[0, "#E8F5E9"], [0.5, "#FFC107"], [1.0, "#F44336"]],
        colorbar_title="Variability",
        hovertemplate="Contig: %{y}<br>Bin: %{x}<br>Score: %{z:.1f}<extra></extra>",
    )
    fig = go.Figure(data=heatmap)

    # Angled, size-limited tick labels keep the x-axis readable.
    x_axis = dict(tickangle=-45, nticks=20, tickfont=dict(size=10))
    fig.update_layout(
        title=f"Genomic Variability Heatmap (Top {top_n_contigs} contigs)",
        xaxis_title="Genomic position",
        yaxis_title="Contig",
        height=600,
        xaxis=x_axis,
        margin=dict(b=80),
    )
    return fig
def on_contig_selected(contig_id: str, data: dict, state: AppState) -> tuple:
    """Build the gene track plot and gene table for one contig.

    Args:
        contig_id: Contig to display; an empty value yields an empty figure
            and an empty table.
        data: Loaded tables; reads ``data["gff_index"]`` and
            ``data["gene_freq"]``.
        state: Session state (not used by this callback).

    Returns:
        ``(figure, table)``. The figure has one marker trace per gene class
        present on the contig; the table lists the contig's genes sorted by
        ``start``.
    """
    if not contig_id:
        return go.Figure(), pd.DataFrame()

    gff = data["gff_index"]
    class_info = data["gene_freq"][["gene_id", "core_class", "freq_pct"]]

    on_contig = gff[gff["contig_id"] == contig_id]
    contig_genes = on_contig.merge(class_info, on="gene_id", how="left")
    # Genes without a frequency record get the fallback class.
    contig_genes["core_class"] = contig_genes["core_class"].fillna("unknown")

    fig = go.Figure()
    if len(contig_genes) == 0:
        fig.add_annotation(text="No genes on this contig", showarrow=False)
        return fig, pd.DataFrame()

    for class_name, class_color in CORE_COLORS.items():
        members = contig_genes[contig_genes["core_class"] == class_name]
        if not len(members):
            continue
        tick_marker = dict(
            symbol="line-ew",
            size=12,
            color=class_color,
            line=dict(width=2, color=class_color),
        )
        fig.add_trace(go.Scatter(
            x=(members["start"] + members["end"]) / 2,  # gene midpoint
            y=[class_name] * len(members),
            mode="markers",
            marker=tick_marker,
            name=class_name.capitalize(),
            text=members["gene_id"],
            hovertemplate="Gene: %{text}<br>Position: %{x:,.0f}<extra></extra>",
        ))

    if "|" in contig_id:
        short_name = contig_id.split("|")[-1]
    else:
        short_name = contig_id[:30]
    fig.update_layout(
        title=f"Gene Track: {short_name}",
        xaxis_title="Genomic position (bp)",
        yaxis_title="Gene class",
        showlegend=True,
    )

    table_columns = ["gene_id", "start", "end", "strand", "core_class", "freq_pct"]
    table_df = contig_genes[table_columns].sort_values("start")
    return fig, table_df
| # ============================================================ | |
| # Quest 4 Callbacks | |
| # ============================================================ | |
def get_protein_stats_html(gene_id: str, data: dict) -> str:
    """Render protein length and composition for ``gene_id`` as HTML.

    Returns a prompt paragraph when ``gene_id`` is empty, a "no data"
    paragraph when ``data["protein"]`` has no row for the gene, and a
    ``stat-card`` block built from the first matching row otherwise.
    """
    if not gene_id:
        return "<p>Select a gene</p>"

    protein = data["protein"]
    matches = protein[protein["gene_id"] == gene_id]
    if not len(matches):
        return "<p><i>No protein data available for this gene.</i></p>"

    record = matches.iloc[0]
    length_aa = int(record["protein_length"])
    composition = record["composition_summary"]
    pieces = [
        "<div class='stat-card'>",
        f"<p><b>Protein Length:</b> {length_aa} aa</p>",
        f"<p><b>Top Amino Acids:</b> {composition}</p>",
        "</div>",
    ]
    return "".join(pieces)
def build_backpack_comparison(state: AppState, data: dict) -> go.Figure:
    """Build a bar chart of protein lengths for the pinned (backpack) genes.

    Returns a placeholder figure with a hint when fewer than two genes are
    pinned or there is no state. Only genes with a row in
    ``data["protein"]`` appear in the chart.
    """
    too_few = (not state) or len(state.backpack_genes) < 2
    if too_few:
        hint = go.Figure()
        hint.add_annotation(text="Pin at least 2 genes to compare", showarrow=False)
        return hint

    protein = data["protein"]
    pinned = protein[protein["gene_id"].isin(state.backpack_genes)]
    lengths = pinned["protein_length"]

    bars = go.Bar(
        x=pinned["gene_id"],
        y=lengths,
        marker_color="#2E7D32",
        text=lengths,
        textposition="auto",
    )
    fig = go.Figure(data=[bars])
    fig.update_layout(
        title="Backpack Genes: Protein Length Comparison",
        xaxis_title="Gene",
        yaxis_title="Protein Length (aa)",
    )
    return fig
def _parse_composition_summary(summary) -> dict:
    """Parse ``"AA:pct%, AA:pct%"`` text into ``{AA: pct}``.

    Non-string input (e.g. a missing value) yields an empty dict, and items
    that are not ``name:number`` pairs are skipped instead of raising.
    """
    if not isinstance(summary, str):
        return {}
    parsed = {}
    for item in summary.split(", "):
        parts = item.split(":")
        if len(parts) != 2:
            continue
        try:
            parsed[parts[0].strip()] = float(parts[1].replace("%", ""))
        except ValueError:
            # Malformed percentage: drop this entry, keep the rest.
            continue
    return parsed


def build_composition_heatmap(state: AppState, data: dict) -> go.Figure:
    """Build a gene x amino-acid heatmap for the pinned (backpack) genes.

    Percentages are parsed from the ``composition_summary`` strings in
    ``data["protein"]``; amino acids missing for a gene are shown as 0.

    Returns a placeholder figure with a message when fewer than two genes
    are pinned (or there is no state), or when none of the pinned genes has
    a protein row.
    """
    if not state or len(state.backpack_genes) < 2:
        fig = go.Figure()
        fig.add_annotation(text="Pin at least 2 genes to compare", showarrow=False)
        return fig

    # Parse composition from summary strings
    protein = data["protein"]
    bp_prot = protein[protein["gene_id"].isin(state.backpack_genes)]
    aa_data = {
        gene_id: _parse_composition_summary(summary)
        for gene_id, summary in zip(bp_prot["gene_id"], bp_prot["composition_summary"])
    }

    if not aa_data:
        fig = go.Figure()
        fig.add_annotation(text="No composition data", showarrow=False)
        return fig

    # Columns are genes after construction; transpose so rows are genes.
    df = pd.DataFrame(aa_data).fillna(0).T
    fig = go.Figure(data=go.Heatmap(
        z=df.values,
        x=df.columns.tolist(),
        y=df.index.tolist(),
        colorscale="YlGn",
        colorbar_title="%",
    ))
    fig.update_layout(
        title="Amino Acid Composition Heatmap",
        xaxis_title="Amino Acid",
        yaxis_title="Gene",
    )
    return fig
| # ============================================================ | |
| # Gene Card Callbacks | |
| # ============================================================ | |
def on_open_gene_card(gene_id: str, state: AppState, data: dict) -> tuple:
    """Open the Gene Card side panel for ``gene_id``.

    Args:
        gene_id: Gene to show; when empty nothing is opened.
        state: Session state, or ``None`` to start a new one.
        data: Loaded tables, passed through to ``build_gene_card``.

    Returns:
        ``(html, is_open, state)``: the rendered card HTML (empty when no
        gene was given), a flag that is ``True`` only when a card was built,
        and the session state.
    """
    # Same guard as the other callbacks: without it a missing state would
    # fail on the attribute assignment below.
    if state is None:
        state = AppState()

    if not gene_id:
        return "", False, state

    state.selected_gene = gene_id
    card = build_gene_card(gene_id, data)
    html = render_gene_card_html(card)
    state.award("Gene Hunter")
    return html, True, state
def on_download_gene_report(state: AppState, data: dict) -> str:
    """Export a report for the currently selected gene.

    Returns whatever ``export_gene_report`` returns for the selected gene,
    or ``None`` when there is no state or no gene is selected.
    """
    has_selection = bool(state) and bool(state.selected_gene)
    if not has_selection:
        return None
    return export_gene_report(state.selected_gene, data)
| # ============================================================ | |
| # Final Report Callbacks | |
| # ============================================================ | |
def on_generate_report(state: AppState, data: dict) -> tuple:
    """Generate the field report and its downloadable exports.

    Awards the "Cartographer" achievement, then builds the report text plus
    JSON and CSV exports.

    Returns:
        ``(report_md, json_file, csv_file, badges_html, state)`` where the
        two file items are visible ``gr.File`` components and
        ``badges_html`` has one badge ``<span>`` per achievement, sorted by
        name.
    """
    session = state if state is not None else AppState()
    session.award("Cartographer")

    report_md = generate_field_report(session, data)
    json_path = export_report_json(session, data)
    csv_path = export_report_csv(session, data)

    # Achievement HTML
    badge_spans = [
        f'<span class="achievement-badge">{name}</span>'
        for name in sorted(session.achievements)
    ]
    badges = " ".join(badge_spans)

    json_file = gr.File(value=json_path, visible=True)
    csv_file = gr.File(value=csv_path, visible=True)
    return report_md, json_file, csv_file, badges, session
| # ============================================================ | |
| # Data Health | |
| # ============================================================ | |
def build_data_health_html(validation_report: dict) -> str:
    """Render the validation report as a two-column HTML table.

    Each key/value pair becomes one row. Float values are shown with one
    decimal place and a trailing ``%``; all other values are inserted as-is.
    """
    def _display(value):
        return f"{value:.1f}%" if isinstance(value, float) else value

    table_rows = "".join(
        f"<tr><td><b>{key}</b></td><td>{_display(value)}</td></tr>"
        for key, value in validation_report.items()
    )
    return f"<table style='width:100%'>{table_rows}</table>"
| # Need gr import for Tabs update | |
| import gradio as gr | |