# ZuCo-EEG-Lab / app.py
# Hugging Face Space by st192011 — "Update app.py" (commit 4bcefdd, verified)
import gradio as gr
import torch
import numpy as np
import pandas as pd
import random
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- 1. LOAD ARTIFACTS ---
# The demo ships as a single pre-built "black box" bundle: subject data,
# trained projector models, a pre-computed accuracy table, and the
# alias -> subject mapping.
PKG_PATH = "neuro_semantic_package.pt"
print("๐Ÿš€ System Startup: Loading Artifacts...")

if not os.path.exists(PKG_PATH):
    # Fallback for local testing when the file is not in the repo root.
    POSSIBLE_PATHS = [
        "neuro_semantic_package.pt",
        "/content/drive/MyDrive/Brain2Text_Project/demo_research_v2/neuro_semantic_package.pt",
    ]
    for candidate in POSSIBLE_PATHS:
        if os.path.exists(candidate):
            PKG_PATH = candidate
            break

if not os.path.exists(PKG_PATH):
    raise FileNotFoundError(f"CRITICAL: '{PKG_PATH}' missing. Please upload the .pt file.")

# Load the "Black Box" package.
# map_location='cpu' ensures it runs on basic HF spaces without a GPU.
# NOTE(review): weights_only=False unpickles arbitrary objects; acceptable only
# because the artifact is self-produced and trusted.
PKG = torch.load(PKG_PATH, map_location="cpu", weights_only=False)
DATA = PKG['data']
MODELS = PKG['models']        # The Projectors (EEG -> embedding models)
MATRIX = PKG['matrix']        # Pre-calculated accuracy table
MAPPING = PKG['mapping_key']  # Secret mapping: real subject -> alias

# Inverse mapping (Alias -> Real Subject)
ALIAS_TO_REAL = {alias: real for real, alias in MAPPING.items()}

# Load the frozen emotion decoder.
print("๐Ÿค– Loading RoBERTa-GoEmotions...")
MODEL_NAME = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
classifier.eval()
id2label = classifier.config.id2label
# --- 2. LOGIC FUNCTIONS ---
def get_sentence_options(subject_name):
    """Build a sentence dropdown for the chosen subject.

    Every sentence the subject read is offered as a choice; a random one is
    preselected to encourage exploration.
    """
    sentences = DATA[subject_name]['Text']
    return gr.Dropdown(choices=sentences, value=random.choice(sentences))
def get_warning_status(subject, projector_alias):
    """Return a markdown banner flagging train/test data leakage.

    The alias (e.g. "Projector A") hides which real subject contributed
    training data; if that subject equals the selected target subject the
    run is a self-test, not a zero-shot evaluation.
    """
    alias_letter = projector_alias.split(" ")[1]
    leaky = ALIAS_TO_REAL.get(alias_letter) == subject
    if not leaky:
        return "โœ… **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training."
    return (
        "โš ๏ธ **WARNING: DATA LEAKAGE DETECTED**\n\n"
        f"The selected Projector ({projector_alias}) includes data from Subject {subject} in its training set.\n"
        "Results will be artificially high (Self-Test). For valid research verification, please select a different Projector."
    )
def get_historical_accuracy(subject, projector_alias):
    """Look up the pre-computed accuracy for a (projector, subject) pair.

    Returns a markdown string; "N/A" when the pair is absent from MATRIX.
    """
    try:
        # MATRIX is a DataFrame: rows = projector aliases, cols = subjects.
        acc = MATRIX.loc[projector_alias, subject]
    except KeyError:
        # BUGFIX: was a bare `except:`, which also swallowed SystemExit /
        # KeyboardInterrupt. DataFrame.loc raises KeyError for missing labels.
        return "**Historical Compatibility:** N/A"
    return f"**Historical Compatibility:** {acc}"
def decode_neuro_semantics(subject, projector_alias, text):
    """Decode emotions from one EEG trial and compare to the text baseline.

    Pipeline:
      1. fetch the EEG recorded while `subject` read `text`;
      2. project EEG into RoBERTa embedding space via the chosen projector;
      3. run the projected vector through the frozen classification head
         ("brain path") and the raw sentence through the full model
         ("text path");
      4. mark a match when the brain's top-1 emotion is in the text's top-2.

    Returns a single-row pandas DataFrame — also on error, so callers can
    safely concatenate results.
    """
    # 1. Fetch Data
    try:
        idx = DATA[subject]['Text'].index(text)
        eeg_input = DATA[subject]['X'][idx].reshape(1, -1)
    except ValueError:
        # BUGFIX: the original returned a (DataFrame, str) tuple here while
        # the success path returned a bare DataFrame; run_batch_analysis
        # feeds the result straight into pd.concat, which would crash.
        # Surface the error as a (non-matching) row instead.
        return pd.DataFrame([{
            "Sentence Stimulus": text,
            "Text Ground Truth (Top 2)": "Error: Data point not found.",
            "Brain Decoding (Top 3)": "",
            "Match": "โŒ",
        }])
    # 2. Project (EEG -> Vector)
    proj_model = MODELS[projector_alias]
    predicted_vector = proj_model.predict(eeg_input)
    tensor_vec = torch.tensor(predicted_vector).float()
    # 3. Decode (Vector -> Emotions)
    with torch.no_grad():
        # Brain Path: feed the projected vector directly into the frozen
        # classification head (dense -> tanh -> out_proj).
        x = classifier.classifier.dense(tensor_vec.unsqueeze(1))
        x = torch.tanh(x)
        logits_b = classifier.classifier.out_proj(x)
        probs_brain = torch.sigmoid(logits_b).squeeze().numpy()
        # Text Path (Ground Truth): full model on the raw sentence.
        inputs = tokenizer(text, return_tensors="pt")
        logits_t = classifier(**inputs).logits
        probs_text = torch.sigmoid(logits_t).squeeze().numpy()
    # 4. Rank & Format
    top3_b = np.argsort(probs_brain)[::-1][:3]
    top2_t = np.argsort(probs_text)[::-1][:2]
    # Check Match (Top-1 Brain vs Top-2 Text)
    brain_top1 = id2label[top3_b[0]]
    text_top2 = [id2label[i] for i in top2_t]
    match_icon = "โœ…" if brain_top1 in text_top2 else "โŒ"
    # Build the one-row result table with human-readable probabilities.
    brain_str = ", ".join([f"{id2label[i]} ({probs_brain[i]:.2f})" for i in top3_b])
    text_str = ", ".join([f"{id2label[i]} ({probs_text[i]:.2f})" for i in top2_t])
    return pd.DataFrame([{
        "Sentence Stimulus": text,
        "Text Ground Truth (Top 2)": text_str,
        "Brain Decoding (Top 3)": brain_str,
        "Match": match_icon,
    }])
def run_batch_analysis(subject, projector_alias):
    """Decode up to five randomly sampled sentences for one subject.

    Returns the concatenated per-sentence result table plus a markdown
    string with the batch accuracy (share of rows marked as a match).
    """
    texts = DATA[subject]['Text']
    sample_size = min(5, len(texts))
    chosen = random.sample(list(range(len(texts))), sample_size)
    frames = [
        decode_neuro_semantics(subject, projector_alias, texts[i])
        for i in chosen
    ]
    final_df = pd.concat(frames)
    # Batch accuracy = fraction of sampled sentences whose brain top-1
    # emotion landed in the text top-2.
    hit_rate = (final_df["Match"] == "โœ…").mean() * 100
    return final_df, f"**Batch Accuracy:** {hit_rate:.1f}%"
# --- 3. UI LAYOUT ---
# Markdown body of the "Project Report" tab; rendered verbatim by gr.Markdown.
REPORT_TEXT = """
### 1. Abstract
This interface demonstrates a **Brain-Computer Interface (BCI)** capable of decoding high-level semantic information directly from non-invasive EEG signals. By aligning biological neural activity with the latent space of Large Language Models (LLMs), we show that it is possible to reconstruct the **emotional sentiment** of a sentence a user is reading, even if the model has **never seen that user's brain data before**.
### 2. The Dataset: ZuCo (Zurich Cognitive Language Processing Corpus)
This project utilizes the **ZuCo 2.0 dataset**, a benchmark for cognitive modeling.
* **Protocol:** Subjects read movie reviews naturally while their brain activity (EEG) and eye movements were recorded.
* **The Challenge:** Unlike synthetic tasks, natural reading involves rapid, complex cognitive processing, making signal decoding significantly harder.
### 3. Methodology: Latent Space Projection
Instead of training a simple classifier to predict "Positive" or "Negative" from brain waves, we employ a **Neuro-Semantic Projector**.
* **The Goal:** To learn a mapping function `f(EEG) โ†’ R^768` that transforms raw brain signals into the high-dimensional embedding space of **RoBERTa**.
* **The Mechanism:** The system projects the EEG signal into a vector. This vector is then fed into a frozen, pre-trained LLM (`roberta-base-go_emotions`) to generate a probability distribution over **28 distinct emotional states** (e.g., *Admiration, Annoyance, Gratitude, Remorse*).
### 4. Experimental Setup: Strict Zero-Shot Evaluation
To ensure scientific rigor, this demo adheres to a **Strict Leave-One-Group-Out** protocol.
* **Disjoint Training:** The "Projectors" available in this demo were trained on a subset of subjects and validated on **completely different subjects**.
* **No Calibration:** The model does not receive any calibration data from the target subject. It must rely on universal neural patterns shared across humans.
### 5. Interpretation of Results
The demo compares two probability distributions for every sentence:
1. **Text Ground Truth:** What the AI model thinks the sentence means based on the text alone.
2. **Brain Prediction:** What the AI model thinks the sentence means based **only** on the user's brain waves.
**Accuracy Metric:** A prediction is considered correct if the **Top-1 Emotion** predicted from the Brain Signal matches either the **#1 or #2 Emotion** predicted from the Text.
"""
# Assemble the two-tab Gradio interface: an interactive demo and the report.
with gr.Blocks(theme=gr.themes.Soft(), title="Neuro-Semantic Decoder") as demo:
    gr.Markdown("# ๐Ÿง  Neuro-Semantic Alignment: Zero-Shot Decoding")
    with gr.Tabs():
        # --- TAB 1: INTERACTIVE DEMO ---
        with gr.TabItem("๐Ÿ”ฎ Interactive Demo"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### โš™๏ธ Configuration")
                    # Subject supplies the EEG; projector supplies the decoder.
                    sub_dropdown = gr.Dropdown(choices=list(DATA.keys()), value="ZKB", label="Select Target Subject (Data Source)")
                    proj_dropdown = gr.Dropdown(choices=list(MODELS.keys()), value="Projector A", label="Select Projector (Decoding Model)")
                    # Dynamic status panels, refreshed on every selection change.
                    warning_box = gr.Markdown("โœ… **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training.")
                    history_box = gr.Markdown("**Historical Compatibility:** 40.0%")
                    btn = gr.Button("๐Ÿ”ฎ Run Batch Analysis (5 Samples)", variant="primary")
                with gr.Column(scale=2):
                    gr.Markdown("### ๐Ÿ“Š Decoding Results")
                    result_table = gr.Dataframe(
                        headers=["Sentence Stimulus", "Text Ground Truth (Top 2)", "Brain Decoding (Top 3)", "Match"],
                        wrap=True
                    )
                    batch_accuracy_box = gr.Markdown("**Batch Accuracy:** -")
            # Keep the leakage warning and historical accuracy in sync with
            # whichever (subject, projector) pair is currently selected.
            for widget in (sub_dropdown, proj_dropdown):
                widget.change(fn=get_warning_status, inputs=[sub_dropdown, proj_dropdown], outputs=warning_box)
                widget.change(fn=get_historical_accuracy, inputs=[sub_dropdown, proj_dropdown], outputs=history_box)
            # Run the 5-sample batch decode on demand.
            btn.click(
                fn=run_batch_analysis,
                inputs=[sub_dropdown, proj_dropdown],
                outputs=[result_table, batch_accuracy_box]
            )
        # --- TAB 2: REPORT ---
        with gr.TabItem("๐Ÿ“˜ Project Report"):
            gr.Markdown(REPORT_TEXT)

if __name__ == "__main__":
    demo.launch()