import streamlit as st
import os
import base64
import pandas as pd
import numpy as np
from src.backend.data_loader import get_metadata


def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` as a base64-encoded ASCII string.

    Used to inline binary assets (images) directly into generated HTML.
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()


def _fmt(value):
    """Format an integer with thousands separators (e.g. 12345 -> '12,345')."""
    return f"{value:,}"


def _empty_species_stats():
    """Return the placeholder stats shown when no metadata is available."""
    return {'total': "0", 'spots': "0", 'organs': []}


def _species_stats(spec_df, spot_col):
    """Summarise the metadata rows of a single species.

    Args:
        spec_df: metadata rows belonging to one species.
        spot_col: name of the (already numeric) spot-count column, or None
            when the metadata has no such column.

    Returns:
        dict with the formatted sample count ('total'), the formatted total
        spot count ('spots'), and 'organs': a list of per-organ dicts
        ({'name', 'samples', 'spots'}) sorted by sample count, descending.
    """
    total_samples = len(spec_df)
    total_spots = spec_df[spot_col].sum() if spot_col else 0

    organ_rows = []
    if 'organ' in spec_df.columns:
        for name, group in spec_df.groupby('organ'):
            organ_spots = group[spot_col].sum() if spot_col else 0
            organ_rows.append((len(group), str(name).upper(), organ_spots))

    # Sort on the raw count instead of re-parsing the formatted "1,234"
    # strings. list.sort is stable (also with reverse=True), so organs with
    # equal counts keep groupby's sorted-key order.
    organ_rows.sort(key=lambda row: row[0], reverse=True)

    organs_data = [
        {
            'name': name,
            'samples': _fmt(count),
            'spots': _fmt(int(organ_spots)) if organ_spots > 0 else "0",
        }
        for count, name, organ_spots in organ_rows
    ]
    return {
        'total': _fmt(total_samples),
        'spots': _fmt(int(total_spots)) if total_spots > 0 else "0",
        'organs': organs_data,
    }


def get_header_stats():
    """Calculate real-time stats for the header banner for ALL organs using correct metadata columns.

    Returns:
        {'human': {...}, 'mouse': {...}}, where each value holds the
        formatted sample count, the formatted spot count and a per-organ
        breakdown. Placeholder zeros are returned when no metadata is
        available.
    """
    df = get_metadata()
    if df is None or df.empty or 'species' not in df.columns:
        return {'human': _empty_species_stats(), 'mouse': _empty_species_stats()}

    spot_col = 'spots_under_tissue' if 'spots_under_tissue' in df.columns else None
    if spot_col:
        # Coerce the spot counts once, on a NEW frame. DataFrame.assign never
        # mutates the object returned by get_metadata (which might be a
        # shared, cached object). Writing into a boolean-mask slice instead
        # triggers pandas' SettingWithCopyWarning.
        numeric_spots = pd.to_numeric(df[spot_col], errors='coerce').fillna(0)
        df = df.assign(**{spot_col: numeric_spots})

    species = df['species']
    human_mask = species.str.contains('human|homo', case=False, na=False)
    mouse_mask = species.str.contains('mouse|mus', case=False, na=False)
    return {
        'human': _species_stats(df[human_mask], spot_col),
        'mouse': _species_stats(df[mouse_mask], spot_col),
    }


def _build_circular_organs(organs_list, radius=290):
    """Build the organ-bubble markup for one species.

    Each organ gets an evenly spaced angle starting at -pi/2, and an (x, y)
    offset on a circle of ``radius``.
    """
    n_organs = len(organs_list)  # loop body never runs when 0, so no div-by-zero
    parts = []
    for i, org in enumerate(organs_list):
        angle = (i / n_organs) * 2 * np.pi - (np.pi / 2)
        # NOTE(review): x/y are computed but not referenced in the markup
        # below. Confirm whether positioning styles were intended here.
        x = radius * np.cos(angle)
        y = radius * np.sin(angle)
        parts.append(f'''
{org['name']}
Samples: {org['samples']}
Spots: {org['spots']}
''')
    return "".join(parts)


def render_header():
    """Render a premium atlas header with optimized glassmorphism cards using st.html.

    Injects the stylesheet, computes the per-species stats, and emits the
    title, subtitle and the human/mouse organ bubbles as one HTML fragment.
    """
    load_css()

    h_img_path = "assets/human_red.png"
    m_img_path = "assets/mouse_red.png"
    bg_img_path = "assets/network_bg_red.png"
    # Missing assets degrade to an empty string rather than raising.
    # NOTE(review): these encoded images are not referenced in the markup
    # below. Confirm they are meant to be embedded.
    h_base64 = get_base64_of_bin_file(h_img_path) if os.path.exists(h_img_path) else ""
    m_base64 = get_base64_of_bin_file(m_img_path) if os.path.exists(m_img_path) else ""
    bg_base64 = get_base64_of_bin_file(bg_img_path) if os.path.exists(bg_img_path) else ""

    stats = get_header_stats()

    h_bubbles = _build_circular_organs(stats['human']['organs'], radius=290)
    m_bubbles = _build_circular_organs(stats['mouse']['organs'], radius=290)

    subtitle = "A spatial atlas of tumour microenvironment metabolism and metabolic interactions inferred by a pretrained self-supervised metabolic hypergraph"

    header_html = f"""

spMetaTME-Atlas

{subtitle}

{h_bubbles}
HUMAN ATLAS {stats['human']['total']} Samples | {stats['human']['spots']} Spots
{m_bubbles}
MOUSE ATLAS {stats['mouse']['total']} Samples | {stats['mouse']['spots']} Spots
"""
    st.html(header_html)


@st.cache_resource(show_spinner=False)
def load_css():
    """Load and apply CSS - cached to prevent reloading on every rerun.

    Reads assets/style.css when it exists. The file is decoded as UTF-8
    explicitly, so the result does not depend on the platform's default
    encoding.
    NOTE(review): on cache hits this relies on Streamlit replaying the
    st.markdown calls made inside the cached function. Confirm with the
    installed Streamlit version.
    """
    css_path = "assets/style.css"
    css_content = ""
    if os.path.exists(css_path):
        with open(css_path, encoding="utf-8") as f:
            css_content = f.read()
    st.markdown(""" """, unsafe_allow_html=True)
    if css_content:
        # NOTE(review): the emitted string does not interpolate css_content.
        # Confirm the stylesheet actually reaches the page.
        st.markdown(f"", unsafe_allow_html=True)