st192011 committed on
Commit
5251234
·
verified ·
1 Parent(s): f235648

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import random
6
+ import os
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+
9
# --- 1. LOAD ARTIFACTS ---
# Everything in this section runs once at import time and populates the
# module-level names the logic functions and the UI read from.
PKG_PATH = "neuro_semantic_package.pt"

print("๐Ÿš€ System Startup: Loading Artifacts...")
package_present = os.path.exists(PKG_PATH)
if not package_present:
    # Fail fast with a message that is easy to spot in the web logs.
    raise FileNotFoundError(f"CRITICAL: '{PKG_PATH}' missing. Please upload the .pt file.")

# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only acceptable while the .pt file is a
# trusted artifact shipped with the app -- confirm it never comes from users.
# map_location="cpu" keeps every tensor on the CPU.
PKG = torch.load(PKG_PATH, weights_only=False, map_location="cpu")

# Unpack the four package sections in one go:
#   DATA    - per-subject data (indexed below by subject, then 'Text' / 'X')
#   MODELS  - the projectors, keyed by the name shown in the UI
#   MATRIX  - pre-calculated accuracy table
#   MAPPING - mapping whose values are projector aliases (inverted below)
DATA, MODELS, MATRIX, MAPPING = (
    PKG[section] for section in ('data', 'models', 'matrix', 'mapping_key')
)

# Inverse mapping (alias -> real subject).
ALIAS_TO_REAL = dict(zip(MAPPING.values(), MAPPING.keys()))

# Load the decoder: tokenizer plus classification model, switched to eval mode.
print("๐Ÿค– Loading RoBERTa-GoEmotions...")
MODEL_NAME = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()
id2label = classifier.config.id2label
34
+
35
+ # --- 2. LOGIC FUNCTIONS ---
36
+
37
def get_sentence_options(subject_name):
    """Build the sentence dropdown for one subject.

    Args:
        subject_name: Key into the module-level ``DATA`` mapping.

    Returns:
        A ``gr.Dropdown`` whose choices are the subject's ``'Text'`` entries
        and whose value is one of them picked at random (to encourage
        exploration), or ``None`` when the subject has no sentences.
    """
    choices = DATA[subject_name]['Text']
    # random.choice raises IndexError on an empty sequence; an empty dropdown
    # with no selection is the sensible result for a subject without sentences.
    default = random.choice(choices) if len(choices) > 0 else None
    return gr.Dropdown(choices=choices, value=default)
43
+
44
def get_warning_status(subject, projector_alias):
    """Report whether the chosen projector was trained on the chosen subject.

    Args:
        subject: Target subject identifier; compared against the values of
            the module-level ``ALIAS_TO_REAL`` mapping.
        projector_alias: Display name of the projector. Its second
            whitespace-separated token is the alias key
            (``"Projector A"`` -> ``"A"``); a single-token name is used as is.

    Returns:
        A Markdown string: a data-leakage warning when the projector's source
        subject equals ``subject``, otherwise a "valid zero-shot" notice.
    """
    # The UI may hand over None/"" (e.g. a cleared dropdown) and a name without
    # a space has no second token; indexing [1] blindly raised in both cases.
    tokens = (projector_alias or "").split()
    if len(tokens) > 1:
        clean_alias = tokens[1]
    elif tokens:
        clean_alias = tokens[0]
    else:
        clean_alias = None
    source_subject = ALIAS_TO_REAL.get(clean_alias)

    # An unknown alias yields None, which must never count as a match:
    # otherwise a missing subject would be reported as leakage.
    if source_subject is not None and source_subject == subject:
        return (
            "โš ๏ธ **WARNING: DATA LEAKAGE DETECTED**\n\n"
            f"The selected Projector ({projector_alias}) includes data from Subject {subject} in its training set.\n"
            "Results will be artificially high (Self-Test). For valid research verification, please select a different Projector."
        )
    return "โœ… **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training."
57
+
58
def get_historical_accuracy(subject, projector_alias):
    """Look up the pre-calculated accuracy for a projector/subject pair.

    Args:
        subject: Column label in the module-level ``MATRIX`` table.
        projector_alias: Row label in ``MATRIX``.

    Returns:
        A Markdown string carrying the stored value, or ``N/A`` when the pair
        cannot be looked up in the table.
    """
    try:
        acc = MATRIX.loc[projector_alias, subject]
    except (KeyError, IndexError, TypeError):
        # Only lookup failures (unknown or unusable labels) mean "no data".
        # Anything else is a real bug and should surface instead of being
        # hidden the way the former bare ``except:`` did.
        return "**Historical Compatibility:** N/A"
    return f"**Historical Compatibility:** {acc}"
65
+
66
def decode_neuro_semantics(subject, projector_alias, text):
    """Decode one sentence from EEG and compare it with the text-based result.

    The EEG sample stored for ``text`` is projected by the selected projector,
    pushed through the classifier head (dense -> tanh -> out_proj -> sigmoid)
    and ranked. The same sentence is also classified from its text, which
    serves as the reference.

    Args:
        subject: Key into the module-level ``DATA`` mapping.
        projector_alias: Key into the module-level ``MODELS`` mapping.
        text: Sentence to decode; must be one of ``DATA[subject]['Text']``.

    Returns:
        A ``(DataFrame, status)`` tuple. On success the frame has one row with
        the columns "Sentence Stimulus", "Text Ground Truth (Top 2)",
        "Brain Decoding (Top 3)" and "Match"; the prediction counts as a match
        when the top-1 brain label is among the top-2 text labels. When the
        sentence cannot be found, an empty frame and an error message.
    """
    # 1. Fetch the EEG sample recorded for this sentence.
    try:
        idx = DATA[subject]['Text'].index(text)
        eeg_input = DATA[subject]['X'][idx].reshape(1, -1)
    except ValueError:
        return pd.DataFrame(), "Error: Data point not found."

    # 2. Project (EEG -> vector).
    # assumes predict() returns a (1, hidden_size) array -- TODO confirm
    predicted_vector = MODELS[projector_alias].predict(eeg_input)
    tensor_vec = torch.tensor(predicted_vector).float()

    # 3. Decode (vector -> emotions).
    with torch.no_grad():
        # Brain path: feed the projected vector straight into the head layers.
        head = classifier.classifier
        hidden = torch.tanh(head.dense(tensor_vec.unsqueeze(1)))
        probs_brain = torch.sigmoid(head.out_proj(hidden)).squeeze().numpy()

        # Text path (ground truth). truncation=True caps the input at the
        # tokenizer's maximum length, so an unusually long sentence cannot
        # overflow the model's position embeddings.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        probs_text = torch.sigmoid(classifier(**inputs).logits).squeeze().numpy()

    # 4. Rank & format. argsort yields NumPy integers; cast to plain int so
    # the id2label lookup does not depend on NumPy/int hash equality.
    top3_b = [int(i) for i in np.argsort(probs_brain)[::-1][:3]]
    top2_t = [int(i) for i in np.argsort(probs_text)[::-1][:2]]

    # Match check: top-1 brain label vs. top-2 text labels.
    brain_top1 = id2label[top3_b[0]]
    text_top2 = [id2label[i] for i in top2_t]
    # NOTE(review): the icon literals look mis-decoded (mojibake) in this view;
    # they are kept verbatim because run_batch_analysis compares against them.
    match_icon = "โœ…" if brain_top1 in text_top2 else "โŒ"

    # One-row result table with "label (probability)" summaries.
    brain_str = ", ".join(f"{id2label[i]} ({probs_brain[i]:.2f})" for i in top3_b)
    text_str = ", ".join(f"{id2label[i]} ({probs_text[i]:.2f})" for i in top2_t)

    df = pd.DataFrame([{
        "Sentence Stimulus": text,
        "Text Ground Truth (Top 2)": text_str,
        "Brain Decoding (Top 3)": brain_str,
        "Match": match_icon,
    }])

    return df, f"**Prediction Status:** {match_icon}"
115
+
116
def run_batch_analysis(subject, projector_alias):
    """Decode up to five randomly chosen sentences and summarise the matches.

    Args:
        subject: Key into the module-level ``DATA`` mapping.
        projector_alias: Projector name forwarded to
            ``decode_neuro_semantics``.

    Returns:
        A ``(DataFrame, summary)`` tuple: the per-sentence result rows
        (re-indexed 0..n-1) and a Markdown string with the share of rows whose
        "Match" column holds the success icon. When there is nothing to score
        the frame is empty and the summary reads ``N/A``.
    """
    sentences = DATA[subject]['Text']
    sample_size = min(5, len(sentences))
    if sample_size == 0:
        # pd.concat([]) raises ValueError; report "nothing to score" instead.
        return pd.DataFrame(), "**Batch Accuracy:** N/A"

    picked = random.sample(range(len(sentences)), sample_size)
    frames = [
        decode_neuro_semantics(subject, projector_alias, sentences[i])[0]
        for i in picked
    ]

    # ignore_index gives the rows unique labels instead of five copies of 0.
    final_df = pd.concat(frames, ignore_index=True)
    if "Match" not in final_df.columns:
        # Every decode returned its empty error frame.
        return final_df, "**Batch Accuracy:** N/A"

    # Batch accuracy = share of rows carrying the success icon.
    acc = (final_df["Match"] == "โœ…").mean() * 100
    return final_df, f"**Batch Accuracy:** {acc:.1f}%"
134
+
135
+ # --- 3. UI LAYOUT ---
136
+
137
# Markdown shown in the "Read Project Report" accordion.
# The backslash in the LaTeX snippet is escaped: a bare "\m" is an invalid
# escape sequence in a regular string literal (the runtime text is unchanged).
# NOTE(review): the emoji in the heading looks mis-decoded (mojibake) in this
# view -- confirm the file is stored and served as UTF-8.
INTRODUCTION = """
### ๐Ÿ”ฌ Abstract & Methodology
**Goal:** Zero-Shot decoding of emotional sentiment from raw EEG signals.

**Methodology:**
1. **Input:** EEG signals from the ZuCo 2.0 dataset (Movie Reviews).
2. **Projection:** A Ridge Regression model maps EEG features ($f(EEG)$) to the **RoBERTa-GoEmotions** latent space ($\\mathbb{R}^{768}$).
3. **Inference:** The projected vector is classified by the frozen RoBERTa head to recover the sentiment probability distribution.

**Evaluation Metric:** A prediction is correct if the **Top-1 Brain Prediction** appears within the **Top-2 Text Predictions**.
"""
148
+
149
# Dropdown choices come from the loaded package. Prefer the original defaults,
# but fall back to the first available entry so that a package without "ZKB" /
# "Projector A" does not start with a value missing from its choices.
_subject_choices = list(DATA.keys())
_projector_choices = list(MODELS.keys())
_default_subject = "ZKB" if "ZKB" in _subject_choices else next(iter(_subject_choices), None)
_default_projector = (
    "Projector A" if "Projector A" in _projector_choices
    else next(iter(_projector_choices), None)
)

# NOTE(review): the emoji in the labels below look mis-decoded (mojibake) in
# this view; the literals are kept verbatim -- confirm the file encoding.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐Ÿง  Neuro-Semantic Alignment: Zero-Shot Decoding")

    with gr.Accordion("๐Ÿ“˜ Read Project Report (Abstract & Methodology)", open=False):
        gr.Markdown(INTRODUCTION)

    with gr.Row():
        # Left column: configuration and live status boxes.
        with gr.Column(scale=1):
            gr.Markdown("### โš™๏ธ Configuration")

            # Selectors
            sub_dropdown = gr.Dropdown(
                choices=_subject_choices,
                value=_default_subject,
                label="Select Target Subject (Data Source)",
            )
            proj_dropdown = gr.Dropdown(
                choices=_projector_choices,
                value=_default_projector,
                label="Select Projector (Decoding Model)",
            )

            # Dynamic info boxes. The texts here are placeholders only; they
            # are recomputed on page load and on every selection change.
            warning_box = gr.Markdown("โœ… **VALID ZERO-SHOT CONFIGURATION**\n\nTarget Subject was NOT seen during Projector training.")
            history_box = gr.Markdown("**Historical Compatibility:** 40.0%")

            btn = gr.Button("๐Ÿ”ฎ Run Batch Analysis (5 Samples)", variant="primary")

        # Right column: results.
        with gr.Column(scale=2):
            gr.Markdown("### ๐Ÿ“Š Decoding Results")

            # Output table
            result_table = gr.Dataframe(
                headers=["Sentence Stimulus", "Text Ground Truth (Top 2)", "Brain Decoding (Top 3)", "Match"],
                wrap=True
            )
            batch_accuracy_box = gr.Markdown("**Batch Accuracy:** -")

    # Interactivity: both status boxes depend on both selectors, so each
    # dropdown refreshes both of them.
    _selection = [sub_dropdown, proj_dropdown]
    for _dropdown in (sub_dropdown, proj_dropdown):
        _dropdown.change(fn=get_warning_status, inputs=_selection, outputs=warning_box)
        _dropdown.change(fn=get_historical_accuracy, inputs=_selection, outputs=history_box)

    # Refresh the boxes once when the page loads, so the initial text always
    # matches the actual default selection and cannot drift from the handlers.
    demo.load(fn=get_warning_status, inputs=_selection, outputs=warning_box)
    demo.load(fn=get_historical_accuracy, inputs=_selection, outputs=history_box)

    # Run
    btn.click(
        fn=run_batch_analysis,
        inputs=_selection,
        outputs=[result_table, batch_accuracy_box]
    )

if __name__ == "__main__":
    demo.launch()