Spaces:
Sleeping
Sleeping
| """ | |
| Visualization functions for ChatSpatial Engine. | |
| Generates publication-quality plots from Phoenix expression predictions. | |
| """ | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from typing import Optional | |
| import io | |
| from PIL import Image as PILImage | |
# Dark-theme palette shared by the plots in this module: "immune", "tumor"
# and "stroma" colour the tissue categories; the remaining entries style the
# figure background, text, accents, grid lines and borders.
COLORS = dict(
    immune="#4ECDC4",
    tumor="#FF6B6B",
    stroma="#95E1D3",
    bg="#0F1419",
    card_bg="#1A1F2E",
    text="#E8ECF0",
    accent="#7C3AED",
    grid="#2A3040",
    border="#303848",
)
# Gene symbol -> tissue category ("immune", "tumor" or "stroma"), used to
# colour genes in the bar chart. Groups are listed in the original insertion
# order; immune appears twice so the checkpoint genes stay at the end.
GENE_CATEGORY = {
    gene: category
    for category, genes in (
        ("immune", "CD8A CD8B CD3D CD3E CD4 MS4A1 CD19 CD68 CD163 PTPRC FOXP3"),
        ("tumor", "EPCAM KRT18 KRT7 MKI67 PCNA"),
        ("stroma", "COL1A1 VIM ACTA2 FAP VEGFA"),
        ("immune", "PDCD1 CD274 CTLA4 HLA-A"),
    )
    for gene in genes.split()
}
def gene_expression_bar_chart(
    top_genes: list,
    marker_results: dict,
    max_genes: int = 20,
) -> Optional[plt.Figure]:
    """Draw a horizontal bar chart of the top expressed genes, coloured by category.

    Each bar is coloured by the gene's entry in ``GENE_CATEGORY`` (immune,
    tumor or stroma). Genes not listed there use the "Other" colour.

    Args:
        top_genes: Sequence of ``(gene, value)`` pairs. The first
            ``max_genes`` pairs are plotted in the given order, with the
            first pair at the top.
        marker_results: Not used by this function. Accepted so that
            existing callers keep working.
        max_genes: Maximum number of bars to draw (default 20, the previous
            hard-coded limit).

    Returns:
        A new figure, or ``None`` when there is nothing to plot. The figure
        is created through pyplot, so the caller should close it
        (``fig_to_pil`` does).
    """
    if not top_genes:
        return None
    shown = top_genes[:max_genes]
    if not shown:
        return None
    genes = [gene for gene, _ in shown]
    values = [value for _, value in shown]

    other_color = "#6366F1"
    category_colors = {
        "immune": COLORS["immune"],
        "tumor": COLORS["tumor"],
        "stroma": COLORS["stroma"],
    }
    bar_colors = [
        category_colors.get(GENE_CATEGORY.get(gene, ""), other_color)
        for gene in genes
    ]

    fig, ax = plt.subplots(figsize=(8, 6), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])

    y_pos = np.arange(len(genes))
    ax.barh(y_pos, values, color=bar_colors, edgecolor="none", height=0.7, alpha=0.85)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(genes, fontsize=9, color=COLORS["text"], fontfamily="monospace")
    ax.set_xlabel("Expression (log1p normalized)", fontsize=10, color=COLORS["text"])
    ax.set_title("Top Expressed Genes", fontsize=13, color=COLORS["text"], fontweight="bold", pad=12)
    ax.tick_params(axis="x", colors=COLORS["text"], labelsize=8)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_color(COLORS["border"])
    ax.spines["left"].set_color(COLORS["border"])
    ax.xaxis.grid(True, color=COLORS["grid"], alpha=0.3, linestyle="--")
    ax.set_axisbelow(True)

    legend_patches = [
        mpatches.Patch(color=COLORS["immune"], label="Immune"),
        mpatches.Patch(color=COLORS["tumor"], label="Tumor"),
        mpatches.Patch(color=COLORS["stroma"], label="Stroma"),
        mpatches.Patch(color=other_color, label="Other"),
    ]
    ax.legend(handles=legend_patches, loc="lower right", fontsize=8,
              facecolor=COLORS["card_bg"], edgecolor=COLORS["border"],
              labelcolor=COLORS["text"], framealpha=0.9)

    # After inversion y grows downwards, so the first pair is drawn on top.
    ax.invert_yaxis()
    # Lay out this figure explicitly. plt.tight_layout() acts on pyplot's
    # global "current figure", which need not be `fig` when plots are built
    # from several threads (e.g. concurrent web requests).
    fig.tight_layout()
    return fig
def tissue_composition_radar(cell_type_scores: dict) -> Optional[plt.Figure]:
    """Draw a radar (spider) plot of immune vs tumor vs stroma composition.

    The three scores are divided by the largest score (when it is positive),
    so the biggest category reaches radius 1. Each vertex is annotated with
    its raw score. Categories missing from ``cell_type_scores`` count as 0.

    Args:
        cell_type_scores: Mapping that may contain numeric ``"immune"``,
            ``"tumor"`` and ``"stroma"`` scores.

    Returns:
        A new polar figure, or ``None`` when ``cell_type_scores`` is empty.
        The figure is created through pyplot, so the caller should close it.
    """
    if not cell_type_scores:
        return None
    keys = ("immune", "tumor", "stroma")
    categories = ["Immune", "Tumor", "Stroma"]
    values = [cell_type_scores.get(key, 0) for key in keys]
    vertex_colors = [COLORS[key] for key in keys]

    peak = max(values)
    # Fall back to 1 when no score is positive. This avoids dividing by zero
    # and flipping signs by dividing by a negative peak.
    max_val = peak if peak > 0 else 1
    # NOTE(review): negative scores would plot as negative radii. Confirm
    # that upstream scores are non-negative.
    values_norm = [v / max_val for v in values]
    values_norm += values_norm[:1]  # repeat the first vertex to close the polygon

    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    ax.fill(angles, values_norm, color=COLORS["accent"], alpha=0.25)
    ax.plot(angles, values_norm, color=COLORS["accent"], linewidth=2.5, marker="o", markersize=8)

    for angle, val, raw, color in zip(angles[:-1], values_norm[:-1], values, vertex_colors):
        ax.plot(angle, val, "o", color=color, markersize=12, zorder=5)
        ax.annotate(f"{raw:.1f}", xy=(angle, val), fontsize=9, color=COLORS["text"],
                    ha="center", va="bottom", fontweight="bold",
                    xytext=(0, 10), textcoords="offset points")

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=11, color=COLORS["text"], fontweight="bold")
    ax.set_yticklabels([])
    ax.spines["polar"].set_color(COLORS["border"])
    ax.grid(color=COLORS["grid"], alpha=0.3)
    ax.set_title("Tissue Composition", fontsize=13, color=COLORS["text"],
                 fontweight="bold", pad=20)
    # Lay out this figure explicitly rather than pyplot's global current figure.
    fig.tight_layout()
    return fig
def marker_heatmap(marker_results: dict, max_markers: int = 16) -> Optional[plt.Figure]:
    """Draw a one-row heatmap of the most highly expressed marker genes.

    Args:
        marker_results: Mapping of gene name to a dict holding at least a
            numeric ``"value"`` entry.
        max_markers: Maximum number of markers shown, highest value first
            (default 16, the previous hard-coded limit).

    Returns:
        A new figure, or ``None`` when there is nothing to plot. The figure
        is created through pyplot, so the caller should close it.
    """
    if not marker_results:
        return None
    ranked = sorted(marker_results.items(), key=lambda item: -item[1]["value"])[:max_markers]
    if not ranked:
        return None
    genes = [gene for gene, _ in ranked]
    values = [info["value"] for _, info in ranked]
    peak = max(values)

    cmap = LinearSegmentedColormap.from_list(
        "expression",
        ["#1A1F2E", "#1E3A5F", "#4ECDC4", "#FFD93D", "#FF6B6B"],
    )
    fig, ax = plt.subplots(figsize=(10, 2.5), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])

    data = np.array(values).reshape(1, -1)
    # Leave 10% headroom above the largest value. The `peak > 0` guard keeps
    # vmax above vmin=0; a negative vmax would make Normalize raise at draw time.
    vmax = peak * 1.1 if peak > 0 else 1
    im = ax.imshow(data, aspect="auto", cmap=cmap, vmin=0, vmax=vmax)

    ax.set_xticks(range(len(genes)))
    ax.set_xticklabels(genes, fontsize=8, color=COLORS["text"], rotation=45, ha="right",
                       fontfamily="monospace")
    ax.set_yticks([])

    # Cells above half the peak map to the bright (teal/yellow/red) end of
    # the colormap, where light text is unreadable, so they get the dark
    # background colour. Darker cells keep the light theme text colour.
    for i, val in enumerate(values):
        text_color = COLORS["bg"] if val > peak * 0.5 else COLORS["text"]
        ax.text(i, 0, f"{val:.2f}", ha="center", va="center",
                fontsize=7, color=text_color, fontweight="bold")

    # Attach the colorbar through this figure, not pyplot's global current figure.
    cbar = fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.02, pad=0.02)
    cbar.set_label("Expression", fontsize=8, color=COLORS["text"])
    cbar.ax.tick_params(colors=COLORS["text"], labelsize=7)
    ax.set_title("Marker Gene Expression Heatmap", fontsize=11, color=COLORS["text"],
                 fontweight="bold", pad=8)
    ax.spines[:].set_visible(False)
    fig.tight_layout()
    return fig
def generate_all_plots(phoenix_result: dict) -> dict:
    """Build every plot for which the Phoenix result has data.

    Args:
        phoenix_result: Dict that may contain ``"top_genes"``,
            ``"cell_type_scores"`` and/or ``"marker_results"``.

    Returns:
        Dict mapping any of ``"bar_chart"``, ``"radar"`` and ``"heatmap"``
        to a figure. A key is absent when its input is missing or its plot
        function returned ``None``.

    Raises:
        Whatever the individual plot functions raise. Figures already built
        are closed first, so they are not leaked in pyplot's registry.
    """
    plots = {}
    try:
        if "top_genes" in phoenix_result:
            fig = gene_expression_bar_chart(
                phoenix_result["top_genes"],
                phoenix_result.get("marker_results", {}),
            )
            if fig is not None:
                plots["bar_chart"] = fig
        if "cell_type_scores" in phoenix_result:
            fig = tissue_composition_radar(phoenix_result["cell_type_scores"])
            if fig is not None:
                plots["radar"] = fig
        if "marker_results" in phoenix_result:
            fig = marker_heatmap(phoenix_result["marker_results"])
            if fig is not None:
                plots["heatmap"] = fig
    except Exception:
        # A later plot failed. Release the earlier figures before propagating.
        for built in plots.values():
            plt.close(built)
        raise
    return plots
def fig_to_pil(fig: plt.Figure, dpi: int = 150) -> PILImage.Image:
    """Render a matplotlib figure to PNG in memory and return it as a PIL image.

    The figure is always closed afterwards, even if rendering fails. It is
    then released from pyplot's figure registry and must not be reused.

    Args:
        fig: Figure to render. Its own face colour is kept in the output.
        dpi: Output resolution (default 150).

    Returns:
        A fully loaded PIL image that no longer depends on the temporary buffer.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight",
                        facecolor=fig.get_facecolor(), edgecolor="none")
            buf.seek(0)
            # PIL opens images lazily. copy() forces the pixel data to load
            # while the buffer is still open.
            with PILImage.open(buf) as src:
                img = src.copy()
    finally:
        plt.close(fig)
    return img