rntc's picture
Upload app.py with huggingface_hub
f65575b verified
"""
Gradio app to explore pancreas cancer clinical report annotations.
"""
import gradio as gr
from datasets import load_dataset
# Load the "train" split of the annotation dataset once, at import time.
# The two prints only report progress on the console.
print("Loading dataset...")
full_dataset = load_dataset("rntc/biomed-fr-pancreas-annotations", split="train")
print(f"Loaded {len(full_dataset)} samples")
# Minimum number of real (non-placeholder) annotations a sample must carry
# to be kept by the filter below.
MIN_ANNOTATIONS = 10
def count_real_annotations(annotation):
    """Count the real annotations in a sample's annotation mapping.

    An entry counts as real when it is a non-empty dict with a truthy
    ``"value"``, its ``"span"`` does not contain the placeholder phrase
    "pas de mention" (case-insensitive), and its value does not contain
    "not performed" (case-insensitive).

    Args:
        annotation: Mapping of variable name to a dict with ``"value"`` and
            ``"span"`` keys. ``None``, an empty mapping or a non-dict is
            treated as "no annotations".

    Returns:
        The number of real annotations as an ``int``.
    """
    # A missing annotation must not crash the import-time filtering pass.
    if not annotation or not isinstance(annotation, dict):
        return 0

    count = 0
    for var_data in annotation.values():
        if not var_data or not isinstance(var_data, dict):
            continue
        value = var_data.get("value")
        if not value:
            continue
        # The key may be present with a None value, so normalise before
        # lowercasing; str() also tolerates a non-string span.
        span = var_data.get("span") or ""
        if "pas de mention" in str(span).lower():
            continue
        if "not performed" in str(value).lower():
            continue
        count += 1
    return count
# Positions in full_dataset of the samples that carry at least
# MIN_ANNOTATIONS real annotations.
filtered_indices = [
    position
    for position, record in enumerate(full_dataset)
    if count_real_annotations(record.get("annotation", {})) >= MIN_ANNOTATIONS
]
print(f"Filtered to {len(filtered_indices)} samples with >= {MIN_ANNOTATIONS} annotations")
# Background colours for highlighted spans; highlight_text cycles through
# them, one colour per distinct variable.
COLORS = [
    "#FFEB3B",  # yellow
    "#4CAF50",  # green
    "#2196F3",  # blue
    "#FF9800",  # orange
    "#E91E63",  # pink
    "#9C27B0",  # purple
    "#00BCD4",  # cyan
    "#8BC34A",  # light green
    "#FF5722",  # deep orange
    "#607D8B",  # blue grey
]
def escape_html(text):
    """Escape *text* for safe inclusion in HTML content or attribute values.

    ``&``, ``<``, ``>``, ``"`` and ``'`` are replaced by HTML entities, so the
    result can be placed inside a quoted attribute (such as ``title="..."``)
    as well as in element content.

    Args:
        text: Any object; it is converted with ``str()`` before escaping.

    Returns:
        The escaped string, or ``""`` when *text* is falsy (``None``, ``""``,
        ``0``, ...).
    """
    # Imported locally so the function does not depend on a module-level
    # name (other functions in this file use ``html`` as a local variable).
    from html import escape

    if not text:
        return ""
    return escape(str(text), quote=True)
def highlight_text(cr_text, annotation):
    """Return *cr_text* as HTML with the annotated spans highlighted.

    Each annotation entry whose ``"span"`` occurs verbatim in *cr_text* (and
    whose ``"value"`` is truthy) is wrapped in a ``<mark>`` element. The
    element's ``title`` shows "<Variable>: <value>" on hover. Only the first
    occurrence of a span is highlighted, and a span that overlaps an earlier
    highlighted span is dropped.

    Args:
        cr_text: The report text to display.
        annotation: Mapping of variable name to a dict with ``"span"`` and
            ``"value"`` keys; may be empty or ``None``.

    Returns:
        An HTML string consisting of a single ``<pre>`` element.
    """
    if not cr_text or not annotation:
        return f"<pre style='white-space:pre-wrap;'>{escape_html(cr_text)}</pre>"

    def _attr(text):
        # The title attribute is double-quoted, so a raw '"' in the data
        # would terminate it early. This is a no-op if escape_html already
        # escapes quotes.
        return escape_html(text).replace('"', "&quot;")

    # Collect the spans that actually occur in the text.
    spans = []
    for var_name, var_data in annotation.items():
        if not var_data or not isinstance(var_data, dict):
            continue
        span = var_data.get("span")
        value = var_data.get("value")
        # Non-string spans cannot be located in the text; skip them.
        if not span or not value or not isinstance(span, str):
            continue
        start = cr_text.find(span)  # first occurrence only
        if start == -1:
            continue
        spans.append({
            "text": span,
            "start": start,
            "end": start + len(span),
            "var": var_name.replace("_", " ").title(),
            "value": str(value),
        })

    if not spans:
        return f"<pre style='white-space:pre-wrap;'>{escape_html(cr_text)}</pre>"

    # Sort by position (stable, so ties keep annotation order) and keep only
    # spans that start at or after the end of the previously kept span.
    spans.sort(key=lambda s: s["start"])
    kept = []
    for s in spans:
        if not kept or s["start"] >= kept[-1]["end"]:
            kept.append(s)

    # Interleave escaped plain text with <mark> elements.
    parts = []
    pos = 0
    color_map = {}
    for s in kept:
        if s["start"] > pos:
            parts.append(escape_html(cr_text[pos:s["start"]]))
        # One colour per variable, assigned in order of first appearance and
        # cycling through the palette.
        if s["var"] not in color_map:
            color_map[s["var"]] = COLORS[len(color_map) % len(COLORS)]
        color = color_map[s["var"]]
        parts.append(
            f'<mark style="background:{color};padding:1px 3px;border-radius:3px;" '
            f'title="{_attr(s["var"])}: {_attr(s["value"])}">'
            f'{escape_html(s["text"])}</mark>'
        )
        pos = s["end"]
    if pos < len(cr_text):
        parts.append(escape_html(cr_text[pos:]))

    return f"<pre style='white-space:pre-wrap;line-height:1.6;'>{''.join(parts)}</pre>"
def format_table(annotation):
    """Render the annotation mapping as a three-column HTML table.

    The columns are the variable name, its value and the source span. An
    entry without a value, or one that is a "pas de mention" / "not
    performed" placeholder, is shown with "/" as value and an empty source.
    Source spans longer than 60 characters are truncated with "...".
    Entries that are empty or not dicts are skipped.
    """
    if not annotation:
        return "<p>No annotations</p>"

    body_rows = []
    for key, entry in annotation.items():
        if not (entry and isinstance(entry, dict)):
            continue

        value = entry.get("value")
        span = entry.get("span", "")
        label = key.replace("_", " ").title()

        # Same evaluation order as the placeholder checks used for counting.
        is_placeholder = (
            not value
            or (span and "pas de mention" in span.lower())
            or "not performed" in str(value).lower()
        )
        if is_placeholder:
            shown_value, shown_span = "/", ""
        else:
            shown_value = str(value)
            if span and len(span) > 60:
                shown_span = span[:60] + "..."
            else:
                shown_span = span or ""

        body_rows.append(
            "<tr>\n"
            f'<td style="padding:6px 10px;border-bottom:1px solid #ddd;font-weight:500;">{escape_html(label)}</td>\n'
            f'<td style="padding:6px 10px;border-bottom:1px solid #ddd;color:#1565C0;">{escape_html(shown_value)}</td>\n'
            f'<td style="padding:6px 10px;border-bottom:1px solid #ddd;color:#666;font-size:12px;font-style:italic;">{escape_html(shown_span)}</td>\n'
            "</tr>"
        )

    joined_rows = "".join(body_rows)
    return (
        '<table style="width:100%;border-collapse:collapse;font-size:13px;">\n'
        '<thead><tr style="background:#f5f5f5;">\n'
        '<th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Variable</th>\n'
        '<th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Value</th>\n'
        '<th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Source</th>\n'
        "</tr></thead>\n"
        f"<tbody>{joined_rows}</tbody>\n"
        "</table>"
    )
def display_sample(slider_idx):
    """Build the three HTML panels for the sample at the given slider position.

    Args:
        slider_idx: Position within ``filtered_indices`` (converted with
            ``int()``).

    Returns:
        A 3-tuple of HTML strings: the original text, the CR text with
        highlighted spans preceded by a short header, and the annotation
        table. When the position is out of range, all three are "Invalid".
    """
    position = int(slider_idx)
    if not 0 <= position < len(filtered_indices):
        return ("Invalid",) * 3

    # Map the slider position back to the row in the unfiltered dataset.
    dataset_row = filtered_indices[position]
    record = full_dataset[dataset_row]
    annotation = record.get("annotation", {})

    header = (
        f"<p><b>Sample #{dataset_row}</b> — "
        f"{count_real_annotations(annotation)} annotations</p>"
    )
    source_panel = (
        "<pre style='white-space:pre-wrap;line-height:1.6;'>"
        f"{escape_html(record.get('original_text', ''))}</pre>"
    )
    cr_panel = header + highlight_text(record.get("CR", ""), annotation)
    return source_panel, cr_panel, format_table(annotation)
# Build UI: a sample slider on top, then three side-by-side panels
# (original text, generated CR with highlights, extracted variables).
with gr.Blocks(title="Pancreas Annotations", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🔬 Pancreas Cancer Annotations Explorer")
    gr.Markdown(f"Showing {len(filtered_indices)} samples with >= {MIN_ANNOTATIONS} annotations. Hover over highlights to see values.")
    with gr.Row():
        # Clamp the upper bound to 0 so the slider never gets maximum < minimum
        # when no sample passes the filter; display_sample then reports
        # "Invalid" for position 0.
        slider = gr.Slider(0, max(len(filtered_indices) - 1, 0), value=0, step=1, label="Sample")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Original (English)")
            original_html = gr.HTML()
        with gr.Column():
            gr.Markdown("### Generated CR (French)")
            cr_html = gr.HTML()
        with gr.Column():
            gr.Markdown("### Extracted Variables")
            table_html = gr.HTML()
    # Refresh the three panels whenever the slider moves, and once on page load.
    slider.change(display_sample, inputs=[slider], outputs=[original_html, cr_html, table_html])
    demo.load(display_sample, inputs=[slider], outputs=[original_html, cr_html, table_html])

# Start the server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()