# NOTE: the three lines that preceded this file ("Spaces: / Sleeping / Sleeping")
# were Hugging Face Space status text captured by a page scrape, not source code.
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import random | |
| import os | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# --- 1. LOAD ARTIFACTS ---
# Module-level startup: locate the packaged artifacts, unpickle them, and pull
# down the frozen RoBERTa emotion classifier. Runs once at import time.
PKG_PATH = "neuro_semantic_package.pt"
print("๐ System Startup: Loading Artifacts...")
if not os.path.exists(PKG_PATH):
    # Fallback for local testing if file isn't in root
    POSSIBLE_PATHS = [
        "neuro_semantic_package.pt",
        "/content/drive/MyDrive/Brain2Text_Project/demo_research_v2/neuro_semantic_package.pt"
    ]
    for p in POSSIBLE_PATHS:
        if os.path.exists(p):
            PKG_PATH = p
            break
# Re-check after the fallback scan: fail fast with a clear message if the
# package is missing everywhere.
if not os.path.exists(PKG_PATH):
    raise FileNotFoundError(f"CRITICAL: '{PKG_PATH}' missing. Please upload the .pt file.")
# Load the "Black Box" package
# map_location='cpu' ensures it runs on basic HF spaces without GPU if needed
# NOTE(review): weights_only=False unpickles arbitrary Python objects — safe
# only because the .pt file is shipped with this Space, never user-supplied.
PKG = torch.load(PKG_PATH, map_location="cpu", weights_only=False)
# PKG is a dict; the four keys accessed below define its expected schema.
DATA = PKG['data']
MODELS = PKG['models']  # The Projectors
MATRIX = PKG['matrix']  # Pre-calculated Accuracy Table
MAPPING = PKG['mapping_key']  # Secret Mapping
# Inverse mapping (Alias -> Real Sub)
ALIAS_TO_REAL = {v: k for k, v in MAPPING.items()}
# Load Decoder
print("๐ค Loading RoBERTa-GoEmotions...")
MODEL_NAME = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
classifier.eval()  # inference mode: disable dropout etc.
id2label = classifier.config.id2label  # index -> emotion name (28 GoEmotions labels)
| # --- 2. LOGIC FUNCTIONS --- | |
def get_sentence_options(subject_name):
    """Build a Dropdown update listing every sentence recorded for a subject.

    A random sentence is preselected so users are nudged toward exploring
    different stimuli instead of always seeing the first one.
    """
    sentences = DATA[subject_name]['Text']
    preselected = random.choice(sentences)
    return gr.Dropdown(choices=sentences, value=preselected)
def get_warning_status(subject, projector_alias):
    """Check for data leakage between the target subject and the projector.

    Projector labels look like "Projector A"; the token after the first space
    is the alias looked up in ALIAS_TO_REAL to find the subject whose data the
    projector was trained on. If that subject equals the selected target
    subject, the configuration is a self-test and a warning is returned.

    Returns a Markdown status string (warning or valid-configuration notice).
    """
    # Guard against labels without a space: the original indexed
    # split(" ")[1] unconditionally, which raised IndexError for such labels.
    parts = projector_alias.split(" ")
    clean_alias = parts[1] if len(parts) > 1 else projector_alias
    source_subject = ALIAS_TO_REAL.get(clean_alias)
    if source_subject == subject:
        return (
            "โ ๏ธ **WARNING: DATA LEAKAGE DETECTED**\n\n"
            f"The selected Projector ({projector_alias}) includes data from Subject {subject} in its training set.\n"
            "Results will be artificially high (Self-Test). For valid research verification, please select a different Projector."
        )
    else:
        return "โ **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training."
def get_historical_accuracy(subject, projector_alias):
    """Look up the pre-computed accuracy for a (projector, subject) pair.

    Reads the module-level MATRIX table (rows = projector aliases,
    columns = subjects, judging by this lookup).

    Returns a Markdown string; "N/A" when the pair cannot be resolved.
    """
    try:
        acc = MATRIX.loc[projector_alias, subject]
    except Exception:
        # Narrowed from a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Any lookup failure degrades gracefully to N/A.
        return "**Historical Compatibility:** N/A"
    return f"**Historical Compatibility:** {acc}"
def decode_neuro_semantics(subject, projector_alias, text):
    """Decode one EEG trial into emotions and compare against the text itself.

    Pipeline: look up the EEG trial recorded while *subject* read *text*,
    project it into RoBERTa's latent space with the selected projector, and
    run the frozen classification head on that vector (brain path). The full
    text model provides the ground-truth distribution (text path).

    Returns a one-row DataFrame with the stimulus, the text Top-2 emotions,
    the brain Top-3 emotions, and a match icon. On a failed lookup, returns
    an EMPTY DataFrame with the same columns — the original returned a
    (DataFrame, str) tuple here, which broke pd.concat in run_batch_analysis.
    """
    columns = ["Sentence Stimulus", "Text Ground Truth (Top 2)",
               "Brain Decoding (Top 3)", "Match"]
    # 1. Fetch the EEG trial matching the selected sentence
    try:
        idx = DATA[subject]['Text'].index(text)
        eeg_input = DATA[subject]['X'][idx].reshape(1, -1)
    except ValueError:
        # Sentence not found for this subject; keep the return type uniform.
        return pd.DataFrame(columns=columns)
    # 2. Project (EEG -> latent vector)
    # assumes proj_model.predict returns a (1, hidden) array — TODO confirm
    proj_model = MODELS[projector_alias]
    predicted_vector = proj_model.predict(eeg_input)
    tensor_vec = torch.tensor(predicted_vector).float()
    # 3. Decode (vector -> emotion probabilities)
    with torch.no_grad():
        # Brain path: feed the projected vector straight into the frozen
        # classification head (dense -> tanh -> out_proj), bypassing the encoder.
        x = classifier.classifier.dense(tensor_vec.unsqueeze(1))
        x = torch.tanh(x)
        logits_b = classifier.classifier.out_proj(x)
        probs_brain = torch.sigmoid(logits_b).squeeze().numpy()
        # Text path (ground truth): full model on the raw sentence.
        inputs = tokenizer(text, return_tensors="pt")
        logits_t = classifier(**inputs).logits
        probs_text = torch.sigmoid(logits_t).squeeze().numpy()
    # 4. Rank & format
    top3_b = np.argsort(probs_brain)[::-1][:3]
    top2_t = np.argsort(probs_text)[::-1][:2]
    # Match rule: brain Top-1 must appear within the text Top-2.
    brain_top1 = id2label[top3_b[0]]
    text_top2 = [id2label[i] for i in top2_t]
    match_icon = "โ " if brain_top1 in text_top2 else "โ"
    # Human-readable "label (prob)" listings for the result table.
    brain_str = ", ".join(f"{id2label[i]} ({probs_brain[i]:.2f})" for i in top3_b)
    text_str = ", ".join(f"{id2label[i]} ({probs_text[i]:.2f})" for i in top2_t)
    return pd.DataFrame([{
        "Sentence Stimulus": text,
        "Text Ground Truth (Top 2)": text_str,
        "Brain Decoding (Top 3)": brain_str,
        "Match": match_icon,
    }])
def run_batch_analysis(subject, projector_alias):
    """Decode up to 5 randomly sampled sentences for *subject* and score them.

    Returns a tuple of (results DataFrame with one row per sentence,
    Markdown string with the batch accuracy percentage).
    """
    subject_data = DATA[subject]
    total_indices = list(range(len(subject_data['Text'])))
    # Sample without replacement, capped at 5 even for small subjects.
    selected_indices = random.sample(total_indices, min(5, len(total_indices)))
    results = []
    for idx in selected_indices:
        txt = subject_data['Text'][idx]
        results.append(decode_neuro_semantics(subject, projector_alias, txt))
    # ignore_index: every per-sentence frame carries index 0, so renumber rows
    # instead of showing a column of duplicate zeros in the UI table.
    final_df = pd.concat(results, ignore_index=True)
    # Batch accuracy = fraction of rows whose brain Top-1 hit the text Top-2.
    acc = (final_df["Match"] == "โ ").mean() * 100
    return final_df, f"**Batch Accuracy:** {acc:.1f}%"
# --- 3. UI LAYOUT ---
# Formatted Report Text
# Static Markdown shown on the "Project Report" tab. This is a runtime string
# rendered to users, so its content (including special characters) is kept
# exactly as authored.
REPORT_TEXT = """
### 1. Abstract
This interface demonstrates a **Brain-Computer Interface (BCI)** capable of decoding high-level semantic information directly from non-invasive EEG signals. By aligning biological neural activity with the latent space of Large Language Models (LLMs), we show that it is possible to reconstruct the **emotional sentiment** of a sentence a user is reading, even if the model has **never seen that user's brain data before**.
### 2. The Dataset: ZuCo (Zurich Cognitive Language Processing Corpus)
This project utilizes the **ZuCo 2.0 dataset**, a benchmark for cognitive modeling.
* **Protocol:** Subjects read movie reviews naturally while their brain activity (EEG) and eye movements were recorded.
* **The Challenge:** Unlike synthetic tasks, natural reading involves rapid, complex cognitive processing, making signal decoding significantly harder.
### 3. Methodology: Latent Space Projection
Instead of training a simple classifier to predict "Positive" or "Negative" from brain waves, we employ a **Neuro-Semantic Projector**.
* **The Goal:** To learn a mapping function `f(EEG) โ R^768` that transforms raw brain signals into the high-dimensional embedding space of **RoBERTa**.
* **The Mechanism:** The system projects the EEG signal into a vector. This vector is then fed into a frozen, pre-trained LLM (`roberta-base-go_emotions`) to generate a probability distribution over **28 distinct emotional states** (e.g., *Admiration, Annoyance, Gratitude, Remorse*).
### 4. Experimental Setup: Strict Zero-Shot Evaluation
To ensure scientific rigor, this demo adheres to a **Strict Leave-One-Group-Out** protocol.
* **Disjoint Training:** The "Projectors" available in this demo were trained on a subset of subjects and validated on **completely different subjects**.
* **No Calibration:** The model does not receive any calibration data from the target subject. It must rely on universal neural patterns shared across humans.
### 5. Interpretation of Results
The demo compares two probability distributions for every sentence:
1. **Text Ground Truth:** What the AI model thinks the sentence means based on the text alone.
2. **Brain Prediction:** What the AI model thinks the sentence means based **only** on the user's brain waves.
**Accuracy Metric:** A prediction is considered correct if the **Top-1 Emotion** predicted from the Brain Signal matches either the **#1 or #2 Emotion** predicted from the Text.
"""
# Build the Gradio UI: two tabs (interactive demo + static report), wire the
# dropdown change/click events to the logic functions above, then launch.
with gr.Blocks(theme=gr.themes.Soft(), title="Neuro-Semantic Decoder") as demo:
    gr.Markdown("# ๐ง Neuro-Semantic Alignment: Zero-Shot Decoding")
    with gr.Tabs():
        # --- TAB 1: INTERACTIVE DEMO ---
        with gr.TabItem("๐ฎ Interactive Demo"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### โ๏ธ Configuration")
                    # Selectors
                    sub_dropdown = gr.Dropdown(choices=list(DATA.keys()), value="ZKB", label="Select Target Subject (Data Source)")
                    proj_dropdown = gr.Dropdown(choices=list(MODELS.keys()), value="Projector A", label="Select Projector (Decoding Model)")
                    # Dynamic Info Boxes
                    # Initial text mirrors get_warning_status's "valid" branch;
                    # it is refreshed on every dropdown change below.
                    warning_box = gr.Markdown("โ **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training.")
                    history_box = gr.Markdown("**Historical Compatibility:** 40.0%")
                    btn = gr.Button("๐ฎ Run Batch Analysis (5 Samples)", variant="primary")
                with gr.Column(scale=2):
                    gr.Markdown("### ๐ Decoding Results")
                    # Output Table
                    result_table = gr.Dataframe(
                        headers=["Sentence Stimulus", "Text Ground Truth (Top 2)", "Brain Decoding (Top 3)", "Match"],
                        wrap=True
                    )
                    batch_accuracy_box = gr.Markdown("**Batch Accuracy:** -")
            # Interactivity: changing either dropdown refreshes both the
            # leakage warning and the historical-accuracy readout.
            sub_dropdown.change(fn=get_warning_status, inputs=[sub_dropdown, proj_dropdown], outputs=warning_box)
            sub_dropdown.change(fn=get_historical_accuracy, inputs=[sub_dropdown, proj_dropdown], outputs=history_box)
            proj_dropdown.change(fn=get_warning_status, inputs=[sub_dropdown, proj_dropdown], outputs=warning_box)
            proj_dropdown.change(fn=get_historical_accuracy, inputs=[sub_dropdown, proj_dropdown], outputs=history_box)
            # Run
            btn.click(
                fn=run_batch_analysis,
                inputs=[sub_dropdown, proj_dropdown],
                outputs=[result_table, batch_accuracy_box]
            )
        # --- TAB 2: REPORT ---
        with gr.TabItem("๐ Project Report"):
            gr.Markdown(REPORT_TEXT)

if __name__ == "__main__":
    demo.launch()