wack0's picture
Update app.py
020350e verified
import marimo
# Version of marimo that generated this notebook file.
__generated_with = "0.21.1"
# The marimo application object; every `@app.cell` function below is a cell
# registered on it.
app = marimo.App(width="medium")
@app.cell
def _():
    # Notebook-wide imports. Only the names returned here are passed on to
    # the other cells (marimo wires cells together by parameter name).
    import io
    import math

    import marimo as mo
    import pandas as pd

    # Not referenced by name: imported so the openpyxl engine that
    # `pd.read_excel` uses for .xlsx workbooks is present in the environment.
    import openpyxl  # noqa: F401

    return io, math, mo, pd
@app.cell
def _(mo):
    # Three single-file upload buttons, one per Excel workbook the dashboard
    # is built from, stacked under a short instruction heading.
    def _picker(caption):
        return mo.ui.file(label=caption, kind="button", multiple=False)

    upload_yield = _picker("Upload: Coffee_yield.xlsx")
    upload_species = _picker("Upload: Plant_species_and_average...xlsx")
    upload_decomp = _picker("Upload: Total_species_composition.xlsx")

    _heading = mo.md("### Please provide the correct data files to view the visual")
    upload_ui = mo.vstack(
        [_heading, upload_yield, upload_species, upload_decomp],
        align="center",
    )

    # marimo renders the last expression of a cell, so this must stay last.
    upload_ui
    return upload_decomp, upload_species, upload_yield, upload_ui
@app.cell
def _(io, mo, pd, upload_decomp, upload_species, upload_yield):
    # Load the three uploaded workbooks and build `species_data`: one record
    # per species with its group, the sites it occurs in and the mean yield of
    # those sites. Execution of this cell and everything below is paused
    # until all three files are uploaded.
    mo.stop(
        not upload_yield.value or not upload_species.value or not upload_decomp.value,
        mo.md("*Waiting for all three files to be uploaded...*"),
    )

    ###### PREPARATION ######
    # Read the uploaded files from browser memory (first file of each picker)
    df_yield = pd.read_excel(io.BytesIO(upload_yield.value[0].contents))
    df_species = pd.read_excel(io.BytesIO(upload_species.value[0].contents))
    df_decomposition = pd.read_excel(io.BytesIO(upload_decomp.value[0].contents))

    # Hardcoded column names used in the data files
    COL_SPECIES_NAME = "Species name"
    COL_SPECIES_GROUP = "Species group"
    COL_DECOMP_SPECIES = df_decomposition.columns[0]  # first column = species

    # Site ID (as str) -> mean yield, from Coffee_yield.xlsx. Sites with a
    # blank yield are left out so a missing value cannot turn a species'
    # average into NaN (which would propagate into the chart geometry).
    _yield_rows = df_yield.dropna(subset=["Mean_CC_Yield"])
    site_yield_map = dict(
        zip(_yield_rows["Site ID"].astype(str), _yield_rows["Mean_CC_Yield"])
    )

    # In the species workbook, blank group cells are filled from the row above
    df_species[COL_SPECIES_GROUP] = df_species[COL_SPECIES_GROUP].ffill()
    # Group labels as str, in order of first appearance, so they compare
    # equal to the per-species "group" strings assigned below.
    GROUPS = list(dict.fromkeys(str(g) for g in df_species[COL_SPECIES_GROUP].dropna()))

    # In Total_species_composition.xlsx, index rows by species; every other
    # column is a site.
    df_decomposition.set_index(COL_DECOMP_SPECIES, inplace=True)
    ALL_SITES = [str(col) for col in df_decomposition.columns]

    # Presence flags (cell value == 1) per species, keyed by species name as
    # str. Duplicate species rows are merged (present if any row has a 1) so
    # a lookup always yields one row of flags, never a DataFrame.
    _presence = df_decomposition == 1
    _presence = _presence.loc[_presence.index.notna()]
    _presence.index = _presence.index.astype(str)
    _presence = _presence.groupby(level=0).any()

    # Build the species_data list of dicts used to build the visual later
    species_data = []
    for _name, _group in zip(df_species[COL_SPECIES_NAME], df_species[COL_SPECIES_GROUP]):
        # Blank rows, and rows above the first group label, cannot be placed:
        # a NaN group would become "nan", which is not one of GROUPS.
        if pd.isna(_name) or pd.isna(_group):
            continue
        _sp_id = str(_name)

        # Sites where this species occurs, in ALL_SITES column order
        _present_in = []
        if _sp_id in _presence.index:
            _present_in = [
                _site for _site, _flag in zip(ALL_SITES, _presence.loc[_sp_id]) if _flag
            ]

        # Mean yield over those sites that have a recorded yield
        _site_yields = [site_yield_map[_site] for _site in _present_in if _site in site_yield_map]
        _avg_yield = sum(_site_yields) / len(_site_yields) if _site_yields else 0.0

        species_data.append({
            "id": _sp_id,
            "group": str(_group),
            "yield": _avg_yield,
            "num_sites": len(_present_in),
            "sites": _present_in,
        })

    # Sort species from most common (center) to least common (edge)
    species_data.sort(key=lambda rec: rec["num_sites"], reverse=True)
    return ALL_SITES, GROUPS, species_data
@app.cell
def _(ALL_SITES, mo, species_data):
    # UI controls: how many species to draw, and which site(s) to highlight
    # or compare.
    _total = len(species_data)

    # The slider normally spans 5..number of species, starting at half of
    # them. With few species that range would be empty or exclude the
    # default value, so fall back to starting at 1 and clamp the default
    # into [start, stop]. stop is kept above start so the range is never empty.
    _start = 5 if _total > 5 else 1
    _stop = max(_total, _start + 1)
    _value = min(max(_start, _total // 2), _stop)

    # Default dropdown selections; None when there are no sites at all (an
    # empty string would not be one of the options).
    _first = ALL_SITES[0] if ALL_SITES else None
    _second = ALL_SITES[1] if len(ALL_SITES) > 1 else _first

    slider_count = mo.ui.slider(
        start=_start, stop=_stop, step=1, value=_value, label="Species shown:"
    )
    drop_single = mo.ui.dropdown(
        options=ALL_SITES, value=_first, label="Highlight Site:"
    )
    drop_comp1 = mo.ui.dropdown(
        options=ALL_SITES, value=_first, label="Compare Site 1:"
    )
    drop_comp2 = mo.ui.dropdown(
        options=ALL_SITES, value=_second, label="Compare Site 2:"
    )
    return drop_comp1, drop_comp2, drop_single, slider_count
@app.cell
def _(mo):
    # Dark-theme CSS for the site <select> dropdowns and their labels, plus
    # the tab bar whose selected label drives the rest of the dashboard.
    _css = """
<style>
.marimo-dropdown select, select {
appearance: none;
-webkit-appearance: none;
background-color: #1e1e2e;
color: #cdd6f4;
padding: 8px 36px 8px 14px;
border: 1.5px solid #45475a;
border-radius: 10px;
font-size: 13px;
font-family: 'Inter', sans-serif;
cursor: pointer;
min-width: 160px;
transition: border-color 0.2s, box-shadow 0.2s;
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24'%3E%3Cpath fill='%2389b4fa' d='M7 10l5 5 5-5z'/%3E%3C/svg%3E");
background-repeat: no-repeat;
background-position: right 10px center;
}
select:hover { border-color: #89b4fa; }
select:focus {
outline: none;
border-color: #89b4fa;
box-shadow: 0 0 0 3px rgba(137, 180, 250, 0.20);
}
label {
font-size: 11px;
font-family: 'Inter', sans-serif;
font-weight: 600;
letter-spacing: 0.05em;
text-transform: uppercase;
color: #a6adc8;
margin-bottom: 4px;
display: block;
}
</style>
"""
    dropdown_styles = mo.Html(_css)

    # Tab label -> short markdown hint shown inside that tab
    _views = {
        "General Overview": "*Viewing all species colored by their primary group.*",
        "Individual Site": "*Select a site using the dropdown below.*",
        "Compare Sites": "*Select two sites to compare using the dropdowns below.*",
    }
    tabs = mo.ui.tabs({_label: mo.md(_text) for _label, _text in _views.items()})
    return dropdown_styles, tabs
@app.cell
def _(
    drop_comp1,
    drop_comp2,
    drop_single,
    dropdown_styles,
    mo,
    slider_count,
    tabs,
):
    # Assemble the control panel. Which site dropdowns appear depends on the
    # selected tab; the other tabs get an empty placeholder.
    _selector_for_tab = {
        "Individual Site": lambda: mo.hstack([drop_single], justify="center"),
        "Compare Sites": lambda: mo.hstack([drop_comp1, drop_comp2], gap=4),
    }
    _build_selector = _selector_for_tab.get(tabs.value, lambda: mo.Html(""))

    controls = mo.vstack(
        [
            dropdown_styles,
            tabs,
            _build_selector(),
            mo.hstack([slider_count], justify="center"),
        ],
        align="center",
        gap=4,
    )
    return (controls,)
@app.cell
def _(math, tabs):
    # Chart geometry, colour constants and SVG-building helpers shared by the
    # sunburst cells.

    # Sunburst chart: centre position and ring radii (SVG user units)
    CX, CY = 500, 380
    MAX_RADIUS = 350
    MIN_RADIUS = 80
    # Tooltip box width / height
    TW, TH = 245, 105

    # Colors
    preferred_hues = [30, 120, 210]  # HSL hues (degrees) for group colours
    SITE_LEGEND_BKG = "#1e1e2e"
    SITE_LEGEND_BORDER = "#45475a"
    SITE_LEGEND_TEXT = "#cdd6f4"
    SITE_COLOR_S1_ONLY = "#f5b0c6"
    SITE_COLOR_S2_ONLY = "#d8b4fe"
    SITE_COLOR_BOTH = "#f9e2af"
    SITE_COLOR_NEITHER = "#e0e0e0"

    # Helper functions
    def polar_to_cartesian(cx, cy, r, angle_deg):
        """Return the (x, y) point at radius r and angle_deg around (cx, cy).

        Angles are in degrees from the +x axis; because SVG's y axis points
        down, increasing angles run clockwise on screen.
        """
        rad = math.radians(angle_deg)
        return cx + r * math.cos(rad), cy + r * math.sin(rad)

    def build_arc(cx, cy, r_inner, r_outer, start_angle, end_angle):
        """Return SVG path data for the ring segment between two radii and angles.

        Sweeps of 0.05 degrees or less are widened to 0.05 degrees (presumably
        so they stay visible). The large-arc flag is set for sweeps over 180.
        """
        if end_angle - start_angle <= 0.05:
            end_angle = start_angle + 0.05
        p1 = polar_to_cartesian(cx, cy, r_outer, start_angle)
        p2 = polar_to_cartesian(cx, cy, r_outer, end_angle)
        p3 = polar_to_cartesian(cx, cy, r_inner, end_angle)
        p4 = polar_to_cartesian(cx, cy, r_inner, start_angle)
        large_arc = "0" if end_angle - start_angle <= 180 else "1"
        # Outer arc clockwise, line inward, inner arc counter-clockwise, close.
        return f"M {p1[0]} {p1[1]} A {r_outer} {r_outer} 0 {large_arc} 1 {p2[0]} {p2[1]} L {p3[0]} {p3[1]} A {r_inner} {r_inner} 0 {large_arc} 0 {p4[0]} {p4[1]} Z"

    def make_tooltip(unique_id, tx, ty, group_name, species_name, avg_yield,
                     num_sites, canvas_w=1000, canvas_h=800):
        """Return the SVG <g> for one hover tooltip with its top-left near (tx, ty).

        The group's id is `unique_id` with "seg" replaced by "tip". canvas_w
        and canvas_h are the SVG viewBox size the box is kept inside; the
        defaults match the 1000x800 viewBox the dashboard SVG is rendered with.
        """
        from html import escape

        # Bound the tooltip so it doesn't clip off the SVG canvas
        tx = max(10, min(tx, canvas_w - TW - 10))
        ty = max(10, min(ty, canvas_h - TH - 10))
        # Truncate first, then escape (so an entity is never cut in half);
        # escaping keeps names containing "<" or "&" from breaking the SVG.
        safe_group = escape(group_name.strip().upper()[:40])
        safe_name = escape(species_name.strip()[:38])
        yield_str = f"Average yield: {avg_yield:.1f} kg ha\u207b\u00b9"
        sites_str = f"Occurs in: {num_sites} site(s)"
        tip_id = unique_id.replace("seg", "tip")
        return f"""
<g id="{tip_id}" class="tip" transform="translate({tx},{ty})">
<rect width="{TW}" height="{TH}" rx="10" ry="10"
fill="#1e1e2e" stroke="#45475a" stroke-width="1.5"
filter="url(#tipshadow)"/>
<text x="12" y="22" font-family="Inter,sans-serif" font-size="10"
font-weight="700" letter-spacing="1" fill="#89b4fa">{safe_group}</text>
<line x1="12" y1="30" x2="{TW-12}" y2="30" stroke="#45475a" stroke-width="1"/>
<text x="12" y="50" font-family="Inter,sans-serif" font-size="13"
font-weight="700" fill="#cdd6f4">{safe_name}</text>
<text x="12" y="72" font-family="Inter,sans-serif" font-size="12"
fill="#a6e3a1">{yield_str}</text>
<text x="12" y="92" font-family="Inter,sans-serif" font-size="12"
fill="#f9e2af">{sites_str}</text>
</g>"""

    # Referencing `tabs` makes this cell depend on it; presumably kept so this
    # cell (and everything downstream) re-runs when the tab changes.
    _ = tabs
    return (
        CX,
        CY,
        MAX_RADIUS,
        MIN_RADIUS,
        SITE_COLOR_BOTH,
        SITE_COLOR_NEITHER,
        SITE_COLOR_S1_ONLY,
        SITE_COLOR_S2_ONLY,
        SITE_LEGEND_BKG,
        SITE_LEGEND_BORDER,
        SITE_LEGEND_TEXT,
        TH,
        TW,
        build_arc,
        make_tooltip,
        polar_to_cartesian,
        preferred_hues,
    )
@app.cell
def _(GROUPS, preferred_hues, slider_count, species_data, tabs):
    # Pick the species to draw (the first N of species_data, N from the
    # slider), read the active tab, and bucket the species by group.
    active_data = species_data[:slider_count.value]
    active_tab = tabs.value

    # One hue per group, cycling through preferred_hues
    group_hues = {
        group_name: preferred_hues[i % len(preferred_hues)]
        for i, group_name in enumerate(GROUPS)
    }

    grouped_data = {g: [] for g in GROUPS}
    for _s in active_data:
        # A species whose group is not one of GROUPS (e.g. a missing label
        # read as "nan") has no sector to go in; skip it rather than raise
        # KeyError and break the whole dashboard.
        _bucket = grouped_data.get(_s["group"])
        if _bucket is not None:
            _bucket.append(_s)
    return active_tab, group_hues, grouped_data
@app.cell
def _(
    CX,
    CY,
    GROUPS,
    MAX_RADIUS,
    MIN_RADIUS,
    SITE_COLOR_BOTH,
    SITE_COLOR_NEITHER,
    SITE_COLOR_S1_ONLY,
    SITE_COLOR_S2_ONLY,
    TH,
    TW,
    active_tab,
    build_arc,
    drop_comp1,
    drop_comp2,
    drop_single,
    group_hues,
    grouped_data,
    make_tooltip,
    polar_to_cartesian,
):
    # Build the core sunburst: one SVG arc per species, plus its tooltip and
    # the CSS rules that reveal the tooltip on hover.
    # Layout: group k owns the sector starting at k * 120 degrees (the chart
    # assumes three groups). Inside a sector, species are split into
    # concentric rings by how many sites they occur in (most widespread
    # innermost), and each species' angular width in its ring is proportional
    # to its yield (equal widths when the ring's total yield is not positive).
    core_paths = []
    tooltip_elements = []
    css_hover_rules = []
    _SECTOR_DEG = 120
    tab_id_prefix = active_tab.replace(" ", "_").lower()  # unique tab prefix
    seg_index = 0  # unique id counter shared across all segments

    def _group_fill(hue, item_yield, tier_total):
        """HSL fill in the group's hue; a larger share of the ring's yield is darker."""
        lightness = 85 - min(50, (item_yield / max(1, tier_total)) * 45)
        return f"hsl({hue}, 70%, {lightness}%)"

    for group_idx, group_name in enumerate(GROUPS):
        group_items = grouped_data[group_name]
        if not group_items:
            continue
        base_angle = group_idx * _SECTOR_DEG

        # Bucket the group's species by the number of sites they occur in
        tiers = {}
        for item in group_items:
            tiers.setdefault(item["num_sites"], []).append(item)
        sorted_tier_keys = sorted(tiers, reverse=True)
        ring_thickness = (MAX_RADIUS - MIN_RADIUS) / max(1, len(sorted_tier_keys))

        for tier_idx, sites_key in enumerate(sorted_tier_keys):
            tier_items = tiers[sites_key]
            r_inner = MIN_RADIUS + tier_idx * ring_thickness
            r_outer = r_inner + ring_thickness
            total_yield = sum(entry["yield"] for entry in tier_items)
            current_start_angle = base_angle

            for item in tier_items:
                unique_seg_id = f"{tab_id_prefix}-seg-{seg_index}"
                # Derive the tooltip id exactly the way make_tooltip does, so
                # the hover rule always targets the tooltip element it emits.
                unique_tip_id = unique_seg_id.replace("seg", "tip")

                if total_yield > 0:
                    sweep_angle = (item["yield"] / total_yield) * _SECTOR_DEG
                else:
                    sweep_angle = _SECTOR_DEG / len(tier_items)
                end_angle = current_start_angle + sweep_angle

                # Coloring depends on which tab the user is viewing
                opacity = 1.0
                if active_tab == "Compare Sites":
                    in_s1 = drop_comp1.value in item["sites"]
                    in_s2 = drop_comp2.value in item["sites"]
                    if in_s1 and in_s2:
                        fill_color = SITE_COLOR_BOTH
                    elif in_s1:
                        fill_color = SITE_COLOR_S1_ONLY
                    elif in_s2:
                        fill_color = SITE_COLOR_S2_ONLY
                    else:
                        fill_color = SITE_COLOR_NEITHER
                else:
                    fill_color = _group_fill(group_hues[group_name], item["yield"], total_yield)
                    # Individual Site tab: fade species absent from the chosen site
                    if active_tab == "Individual Site" and drop_single.value not in item["sites"]:
                        opacity = 0.15

                path_d = build_arc(CX, CY, r_inner, r_outer, current_start_angle, end_angle)
                core_paths.append(
                    f'<path id="seg-{unique_seg_id}" d="{path_d}" '
                    f'fill="{fill_color}" opacity="{opacity}" '
                    f'stroke="white" stroke-width="2" cursor="pointer"/>'
                )

                # Tooltip anchor: 20 units outside the ring at the segment's
                # mid-angle; on the right half the box extends leftwards.
                angle_center = current_start_angle + sweep_angle / 2
                tip_cx, tip_cy = polar_to_cartesian(CX, CY, r_outer + 20, angle_center)
                tx = tip_cx if tip_cx < CX else tip_cx - TW
                ty = tip_cy - TH / 2
                tooltip_elements.append(
                    make_tooltip(unique_seg_id, tx, ty, group_name,
                                 item["id"], item["yield"], item["num_sites"])
                )

                # Hovering the segment brightens it and reveals its tooltip
                css_hover_rules.append(
                    f"svg:has(#seg-{unique_seg_id}:hover) #seg-{unique_seg_id} "
                    f"{{ filter: brightness(1.35) drop-shadow(0 0 7px rgba(255,255,255,0.6)); }}\n"
                    f"svg:has(#seg-{unique_seg_id}:hover) #{unique_tip_id} "
                    f"{{ visibility: visible; }}"
                )

                current_start_angle = end_angle
                seg_index += 1
    return core_paths, css_hover_rules, tooltip_elements
@app.cell
def _(
    CX,
    CY,
    MAX_RADIUS,
    MIN_RADIUS,
    SITE_COLOR_BOTH,
    SITE_COLOR_NEITHER,
    SITE_COLOR_S1_ONLY,
    SITE_COLOR_S2_ONLY,
    SITE_LEGEND_BKG,
    SITE_LEGEND_BORDER,
    SITE_LEGEND_TEXT,
    active_tab,
    drop_comp1,
    drop_comp2,
    polar_to_cartesian,
):
    # SVG overlays drawn on top of the arcs: the site-comparison legend (only
    # on the "Compare Sites" tab), the three sector dividers, and the group
    # labels. None of them capture pointer events.
    annotation_elements = []

    if active_tab == "Compare Sites":
        _x0, _y0 = 0, 0    # top-left corner of the legend box
        _w, _h = 210, 150  # legend box size
        _pitch = 26        # vertical spacing between legend rows
        _sw = 14           # colour swatch edge length
        _entries = (
            (SITE_COLOR_S1_ONLY, f"Only in {drop_comp1.value}"),
            (SITE_COLOR_S2_ONLY, f"Only in {drop_comp2.value}"),
            (SITE_COLOR_BOTH, "In both sites"),
            (SITE_COLOR_NEITHER, "In neither site"),
        )

        # Legend frame: rounded box, title and separator line
        annotation_elements.append(
            f'<g id="compare-legend" style="pointer-events:none;">'
            f'<rect x="{_x0}" y="{_y0}" width="{_w}" height="{_h}" rx="10" ry="10" '
            f'fill="{SITE_LEGEND_BKG}" stroke="{SITE_LEGEND_BORDER}" stroke-width="1.5" '
            'filter="url(#tipshadow)"/>'
            f'<text x="{_x0 + 12}" y="{_y0 + 22}" font-family="Inter,sans-serif" '
            'font-size="11" font-weight="700" letter-spacing="1" fill="#89b4fa">'
            'SITE COMPARISON</text>'
            f'<line x1="{_x0 + 12}" y1="{_y0 + 30}" x2="{_x0 + _w - 12}" y2="{_y0 + 30}" '
            f'stroke="{SITE_LEGEND_BORDER}" stroke-width="1"/>'
        )

        # One swatch + caption per legend entry
        for _n, (_swatch, _caption) in enumerate(_entries):
            _top = _y0 + 42 + _n * _pitch
            annotation_elements.append(
                f'<rect x="{_x0 + 12}" y="{_top}" width="{_sw}" height="{_sw}" '
                f'rx="3" fill="{_swatch}" stroke="white" stroke-width="1"/>'
                f'<text x="{_x0 + 12 + _sw + 10}" y="{_top + 11}" '
                f'font-family="Inter,sans-serif" font-size="12" fill="{SITE_LEGEND_TEXT}">{_caption}</text>'
            )
        annotation_elements.append("</g>")

    # Sector dividers: radial lines at 0, 120 and 240 degrees
    for _deg in (0, 120, 240):
        _ax, _ay = polar_to_cartesian(CX, CY, MIN_RADIUS, _deg)
        _bx, _by = polar_to_cartesian(CX, CY, MAX_RADIUS, _deg)
        annotation_elements.append(
            f'<line x1="{_ax}" y1="{_ay}" x2="{_bx}" y2="{_by}" '
            f'stroke="#fff" stroke-width="4" style="pointer-events:none;"/>'
        )

    # Group names, mid-sector and just outside the outer ring
    _group_labels = (
        (60, "Woody vascular plants"),
        (180, "Non-woody vascular plants"),
        (300, "Bryophytes"),
    )
    for _deg, _text in _group_labels:
        _lx, _ly = polar_to_cartesian(CX, CY, MAX_RADIUS + 25, _deg)
        annotation_elements.append(
            f'<text x="{_lx}" y="{_ly}" font-family="sans-serif" font-weight="bold" fill="black" text-anchor="middle" style="pointer-events:none;">{_text}</text>'
        )
    return (annotation_elements,)
@app.cell
def _(active_tab, drop_single, mo, species_data):
    # Recommendations panel for the Individual Site tab. It lists the five
    # lowest-"yield" species present at the selected site and the five
    # highest-"yield" species absent from it. On other tabs, or with no
    # site selected, the panel is empty HTML.
    from html import escape as _escape

    recommendations_panel = mo.Html("")
    if active_tab == "Individual Site" and drop_single.value:
        selected_site = drop_single.value
        present = [s for s in species_data if selected_site in s["sites"]]
        not_present = [s for s in species_data if selected_site not in s["sites"]]
        # Top 5 to REMOVE: species present at site with the lowest avg yield
        to_remove = sorted(present, key=lambda x: x["yield"])[:5]
        # Top 5 to ADD: species not present at site with the highest avg yield
        to_add = sorted(not_present, key=lambda x: x["yield"], reverse=True)[:5]
        # Text that comes from the spreadsheets is HTML-escaped so names
        # containing "<" or "&" cannot break the panel markup.
        _site_html = _escape(str(selected_site))

        def make_row(rank, species, color):
            """Return one ranked HTML row: name, group, site count and yield."""
            # Truncate before escaping so an entity is never cut in half
            name_html = _escape(species['id'][:35])
            group_html = _escape(str(species['group']))
            return f"""
<div style="display:flex; align-items:center; gap:10px; padding:6px 0;
border-bottom:1px solid #313244;">
<span style="font-size:11px; color:#585b70; min-width:18px;">#{rank}</span>
<div style="flex:1;">
<div style="font-size:13px; font-weight:600; color:#cdd6f4;">
{name_html}
</div>
<div style="font-size:11px; color:#a6adc8;">
{group_html} &nbsp;·&nbsp; occurs in {species['num_sites']} site(s)
</div>
</div>
<span style="font-size:12px; font-weight:700; color:{color};">
{species['yield']:.1f} kg ha⁻¹
</span>
</div>"""

        remove_rows = "".join(make_row(i+1, s, "#f38ba8") for i, s in enumerate(to_remove))
        add_rows = "".join(make_row(i+1, s, "#a6e3a1") for i, s in enumerate(to_add))
        recommendations_panel = mo.Html(f"""
<div style="display:flex; gap:20px; width:1000px; font-family:'Inter',sans-serif;
box-sizing:border-box;">
<div style="flex:1; background:#1e1e2e; border:1.5px solid #45475a;
border-radius:12px; padding:16px 20px;">
<div style="font-size:10px; font-weight:700; letter-spacing:1px;
color:#f38ba8; margin-bottom:4px;">CONSIDER REMOVING</div>
<div style="font-size:12px; color:#585b70; margin-bottom:12px;">
Lowest-yield species currently at <strong style="color:#cdd6f4;">
{_site_html}</strong>
</div>
{remove_rows}
</div>
<div style="flex:1; background:#1e1e2e; border:1.5px solid #45475a;
border-radius:12px; padding:16px 20px;">
<div style="font-size:10px; font-weight:700; letter-spacing:1px;
color:#a6e3a1; margin-bottom:4px;">CONSIDER ADDING</div>
<div style="font-size:12px; color:#585b70; margin-bottom:12px;">
Highest-yield species absent from <strong style="color:#cdd6f4;">
{_site_html}</strong>
</div>
{add_rows}
</div>
</div>
""")
    return (recommendations_panel,)
@app.cell
def _(
    CX,
    CY,
    MAX_RADIUS,
    MIN_RADIUS,
    active_tab,
    annotation_elements,
    controls,
    core_paths,
    css_hover_rules,
    mo,
    recommendations_panel,
    tooltip_elements,
):
    # Combine everything into the final dashboard:
    #   - the controls;
    #   - one inline SVG holding the arcs, overlays and tooltips;
    #   - the recommendations panel.
    # The base CSS rules, and the first selector of each hover rule, are
    # prefixed with a per-tab container id.
    render_key = active_tab.replace(" ", "-").lower()
    _scope = f"#{render_key}-container"
    _hover_css = "\n".join(_scope + " " + _rule for _rule in css_hover_rules)

    css = (
        "\n<style>\n"
        f"{_scope} .tip {{ visibility: hidden; pointer-events: none; }}\n"
        f'{_scope} path[id^="seg-"] {{ transition: filter 0.15s; }}\n'
        f"{_hover_css}\n"
        "</style>\n"
    )

    _lines = [
        "",
        f'<div id="{render_key}-container">',
        css,
        '<svg width="1000" height="800" viewBox="0 0 1000 800" xmlns="http://www.w3.org/2000/svg">',
        "<defs>",
        '<filter id="tipshadow" x="-5%" y="-5%" width="120%" height="130%">',
        '<feDropShadow dx="0" dy="3" stdDeviation="5" flood-color="#000" flood-opacity="0.45"/>',
        "</filter>",
        "</defs>",
        f'<circle cx="{CX}" cy="{CY}" r="{MIN_RADIUS}" fill="#f4f4f4" stroke="none"/>',
        "".join(core_paths),
        "".join(annotation_elements),
        "".join(tooltip_elements),
        "</svg>",
        "</div>",
        "",
    ]
    svg_markup = "\n".join(_lines)

    final_dashboard = mo.vstack(
        [controls, mo.Html(svg_markup), recommendations_panel],
        align="center",
        gap=0,
    )
    # Last expression of the cell: this is what marimo displays.
    final_dashboard
    return
# When this file is executed as a script, run the marimo app.
if __name__ == "__main__":
    app.run()