| """ |
| Phase 2: Core Model — EmotionAwareMedicalChatbot (v2 — Prefix-Tuning) |
| |
| Architecture: |
| Patient Query |
| ├─→ Longformer Encoder ──→ context embeddings |
| ├─→ ScispaCy Dep Graph → GCN ──→ syntax-aware features |
| ├─→ Frozen Emotion Model ──→ emotion embedding (7-d) |
| └─→ Cross-Attention Fusion ──→ fused context |
| |
| Fused context compressed into N prefix tokens |
| → [PREFIX | Doctor tokens] fed to BioGPT decoder |
| |
| Key Fix (v2): Uses prefix-tuning instead of encoder_hidden_states, |
| because BioGPT is a decoder-only model without cross-attention. |
| """ |
| import torch |
| import torch.nn as nn |
| import json |
| from transformers import ( |
| AutoTokenizer, |
| AutoModel, |
| AutoModelForSequenceClassification, |
| AutoModelForCausalLM, |
| ) |
|
|
| import sys, os |
| sys.path.insert(0, os.path.dirname(__file__)) |
| from config import ( |
| LONGFORMER_MODEL, |
| EMOTION_MODEL, |
| GENERATOR_MODEL, |
| EMOTION_LABELS, |
| NUM_EMOTIONS, |
| MAX_INPUT_TOKENS, |
| MAX_TARGET_TOKENS, |
| NUM_PREFIX_TOKENS, |
| DEVICE, |
| ) |
|
|
|
|
| |
| |
| |
class SimpleGCNLayer(nn.Module):
    """One graph-convolution step with mean aggregation.

    Computes ``activation(linear(D^{-1} A X))``: every node first averages
    the features of the nodes its adjacency row points at, and the averaged
    features are then passed through a learned affine map and a GELU.
    """

    def __init__(self, in_dim, out_dim):
        """
        Args:
            in_dim: size of each incoming node feature vector.
            out_dim: size of each outgoing node feature vector.
        """
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.GELU()

    def forward(self, node_features, adj_matrix):
        """
        Args:
            node_features: (B, N, in_dim) node feature matrix.
            adj_matrix: (B, N, N) adjacency matrix (documented as binary).
        Returns:
            (B, N, out_dim) updated node features.
        """
        # Row sums act as node degrees. Clamping to 1 keeps rows without any
        # edge from dividing by zero; such rows simply aggregate to zeros.
        row_degree = torch.clamp(adj_matrix.sum(dim=-1, keepdim=True), min=1)
        neighbourhood_mean = torch.bmm(adj_matrix / row_degree, node_features)
        projected = self.linear(neighbourhood_mean)
        return self.activation(projected)
|
|
|
|
class SyntaxGCN(nn.Module):
    """Two stacked ``SimpleGCNLayer`` blocks followed by mean pooling.

    Encodes a dependency graph over token features into a single vector
    per batch element.
    """

    def __init__(self, input_dim=768, hidden_dim=512, output_dim=256):
        """
        Args:
            input_dim: width of the incoming node features.
            hidden_dim: width between the two graph-convolution layers.
            output_dim: width of the pooled output vector.
        """
        super().__init__()
        self.gcn1 = SimpleGCNLayer(input_dim, hidden_dim)
        self.gcn2 = SimpleGCNLayer(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, node_features, adj_matrix):
        """
        Args:
            node_features: node features, one row per node, batched on dim 0.
            adj_matrix: adjacency matrix shared by both layers.
        Returns:
            Node states averaged over dim 1 (the node axis), i.e. one
            ``output_dim`` vector per batch element.
        """
        hidden = self.dropout(self.gcn1(node_features, adj_matrix))
        node_states = self.gcn2(hidden, adj_matrix)
        # Plain mean over every node position; no padding mask is applied.
        return torch.mean(node_states, dim=1)
|
|
|
|
| |
| |
| |
class CrossAttentionFusion(nn.Module):
    """Injects the pooled syntax vector into the context sequence.

    The context sequence supplies the attention queries; the syntax vector
    supplies a single key/value token. The attention output is added back
    to the context (residual connection) and layer-normalised.
    """

    def __init__(self, context_dim=768, syntax_dim=256, heads=8):
        """
        Args:
            context_dim: width of the context sequence (query/embed dim).
            syntax_dim: width of the syntax vector (key and value dim).
            heads: number of attention heads.
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=context_dim,
            num_heads=heads,
            kdim=syntax_dim,
            vdim=syntax_dim,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(context_dim)

    def forward(self, context_seq, syntax_vec):
        """
        Args:
            context_seq: (B, seq_len, context_dim) encoder hidden states.
            syntax_vec: (B, syntax_dim) pooled GCN output.
        Returns:
            (B, seq_len, context_dim) fused sequence.
        """
        # The syntax vector becomes a length-1 key/value sequence; it is not
        # tiled across seq_len.
        single_token_memory = syntax_vec[:, None, :]
        attended = self.attn(
            query=context_seq,
            key=single_token_memory,
            value=single_token_memory,
        )[0]
        residual = context_seq + attended
        return self.norm(residual)
|
|
|
|
| |
| |
| |
class ContextCompressor(nn.Module):
    """
    Turns a fused encoder sequence into a fixed number of "prefix tokens"
    living in the decoder's embedding space, so they can be prepended to
    the decoder input embeddings.

    Per the file header, this replaces cross-attention conditioning because
    the decoder is a decoder-only model.
    """

    def __init__(self, encoder_dim, decoder_dim, num_prefix_tokens=8):
        """
        Args:
            encoder_dim: width of the fused encoder sequence.
            decoder_dim: width of the decoder's input embeddings.
            num_prefix_tokens: how many prefix tokens to emit per sample.
        """
        super().__init__()
        self.decoder_dim = decoder_dim
        self.num_prefix = num_prefix_tokens

        # One projection emits all prefix tokens at once as a flat vector.
        flat_width = decoder_dim * num_prefix_tokens
        self.pool_proj = nn.Sequential(
            nn.Linear(encoder_dim, flat_width),
            nn.GELU(),
            nn.LayerNorm(flat_width),
        )

    def forward(self, fused_seq):
        """
        Args:
            fused_seq: (B, S, encoder_dim) fused encoder sequence.
        Returns:
            prefix_embeds: (B, num_prefix, decoder_dim)
        """
        # Unmasked mean over the sequence axis: every position counts equally.
        sequence_mean = torch.mean(fused_seq, dim=1)
        flat_prefix = self.pool_proj(sequence_mean)
        # Unflatten the projection into individual prefix tokens.
        return flat_prefix.view(-1, self.num_prefix, self.decoder_dim)
|
|
|
|
| |
| |
| |
class EmotionAwareMedicalChatbot(nn.Module):
    """
    Prefix-conditioned generation model (v2 — Prefix-Tuning) combining:
        1. An encoder loaded from LONGFORMER_MODEL
        2. A 2-layer syntax GCN over dependency edges
        3. A frozen emotion classifier loaded from EMOTION_MODEL
        4. Cross-attention fusion of context and syntax features
        5. A ContextCompressor that emits prefix tokens
        6. A causal LM decoder loaded from GENERATOR_MODEL, conditioned by
           prepending the prefix tokens to its input embeddings
    """

    def __init__(self):
        super().__init__()

        # 1. Context encoder for the patient query.
        self.encoder_tokenizer = AutoTokenizer.from_pretrained(LONGFORMER_MODEL)
        self.encoder = AutoModel.from_pretrained(LONGFORMER_MODEL)

        # 2. Syntax GCN operating on the encoder's hidden states.
        self.syntax_gcn = SyntaxGCN(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=512,
            output_dim=256,
        )

        # 3. Emotion classifier: frozen, and kept in eval mode (see train()).
        self.emotion_tokenizer = AutoTokenizer.from_pretrained(EMOTION_MODEL)
        self.emotion_model = AutoModelForSequenceClassification.from_pretrained(
            EMOTION_MODEL
        )
        for param in self.emotion_model.parameters():
            param.requires_grad = False
        self.emotion_model.eval()

        # 4. Fusion of context sequence and pooled syntax vector.
        self.cross_attn = CrossAttentionFusion(
            context_dim=self.encoder.config.hidden_size,
            syntax_dim=256,
        )

        # 5. Lifts the emotion probability vector to the encoder width.
        self.emotion_proj = nn.Linear(NUM_EMOTIONS, self.encoder.config.hidden_size)

        # 6. Generative decoder.
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        self.decoder = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)

        # 7. Encoder space -> decoder prefix tokens.
        encoder_dim = self.encoder.config.hidden_size
        decoder_dim = self.decoder.config.hidden_size
        self.context_compressor = ContextCompressor(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            num_prefix_tokens=NUM_PREFIX_TOKENS,
        )

        # 8. Auxiliary head predicting emotion from the first fused position.
        self.emotion_classifier = nn.Linear(encoder_dim, NUM_EMOTIONS)

    def train(self, mode=True):
        """Set train/eval mode, but always leave the frozen emotion model in eval.

        ``nn.Module.train`` recurses into every child, which would otherwise
        switch the frozen emotion model back to train mode (re-enabling its
        dropout) the first time the caller runs ``model.train()``.

        Args:
            mode: True for training mode, False for evaluation mode.
        Returns:
            self
        """
        super().train(mode)
        self.emotion_model.eval()
        return self

    @torch.no_grad()
    def get_emotion_embedding(self, texts):
        """Return the frozen emotion model's softmax probabilities for ``texts``.

        Args:
            texts: batch of input strings accepted by the emotion tokenizer.
        Returns:
            (B, num_labels) tensor of probabilities (rows sum to 1), computed
            without gradient tracking.
        """
        enc = self.emotion_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(next(self.encoder.parameters()).device)

        logits = self.emotion_model(**enc).logits
        return torch.softmax(logits, dim=-1)

    @staticmethod
    def build_adjacency(dep_edges_json, seq_len, device):
        """Build symmetric adjacency matrices with self-loops.

        Args:
            dep_edges_json: one entry per sample; each entry is either a JSON
                string or an already-parsed list of ``[head, child, rel]``.
            seq_len: size of each (square) adjacency matrix.
            device: device the returned tensor is placed on.
        Returns:
            adj: (B, seq_len, seq_len) float tensor.

        Edges whose indices fall outside ``[0, seq_len)`` are dropped. Entries
        that cannot be parsed or iterated are skipped with a warning; edges
        already written for that sample are kept, and the sample still gets
        its self-loops.
        """
        import warnings

        batch_size = len(dep_edges_json)
        # Fill on CPU: element-wise writes to an accelerator tensor are
        # issued one at a time. The finished matrix is moved once at the end.
        adj = torch.zeros(batch_size, seq_len, seq_len)

        for b, raw_edges in enumerate(dep_edges_json):
            try:
                edges = json.loads(raw_edges) if isinstance(raw_edges, str) else raw_edges
                for head, child, _ in edges:
                    # The lower bound matters: a negative index would
                    # otherwise wrap around to the end of the sequence.
                    if 0 <= head < seq_len and 0 <= child < seq_len:
                        adj[b, head, child] = 1.0
                        adj[b, child, head] = 1.0
            except (ValueError, TypeError, IndexError) as err:
                # json.JSONDecodeError is a ValueError; TypeError/IndexError
                # cover None entries and non-integer indices.
                warnings.warn(
                    f"build_adjacency: ignoring malformed dependency edges "
                    f"for sample {b}: {err!r}"
                )

        # Self-loops for every position of every sample, in one write.
        diag = torch.arange(seq_len)
        adj[:, diag, diag] = 1.0
        return adj.to(device)

    def encode(self, patient_texts, dep_edges_json):
        """Run the encoder pipeline.

        Args:
            patient_texts: batch of patient query strings.
            dep_edges_json: per-sample dependency edges (see build_adjacency).
        Returns:
            fused_seq: (B, S, hidden) fused context sequence, with the
                projected emotion vector added to position 0.
            emotion_probs: (B, NUM_EMOTIONS) probabilities from the frozen
                emotion model.
        """
        device = next(self.encoder.parameters()).device

        # Context embeddings.
        enc_inputs = self.encoder_tokenizer(
            patient_texts,
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
            return_tensors="pt",
        ).to(device)
        context_seq = self.encoder(**enc_inputs).last_hidden_state

        # Syntax features over the same positions.
        # NOTE(review): edge indices are used directly as encoder token
        # positions — confirm the upstream parser's indices are aligned with
        # the encoder tokenizer's tokens.
        seq_len = context_seq.size(1)
        adj = self.build_adjacency(dep_edges_json, seq_len, device)
        syntax_vec = self.syntax_gcn(context_seq, adj)

        fused_seq = self.cross_attn(context_seq, syntax_vec)

        # Emotion conditioning on the first position.
        emotion_probs = self.get_emotion_embedding(patient_texts)
        emotion_emb = self.emotion_proj(emotion_probs)
        # Rebuilt out-of-place so no tensor in the autograd graph is mutated.
        first_token = fused_seq[:, :1, :] + emotion_emb.unsqueeze(1)
        fused_seq = torch.cat([first_token, fused_seq[:, 1:, :]], dim=1)

        return fused_seq, emotion_probs

    @staticmethod
    def _prefix_decoder_inputs(prefix_embeds, target_embeds, target_ids,
                               target_attention_mask=None):
        """Assemble decoder embeddings, attention mask and labels for [PREFIX | target].

        Args:
            prefix_embeds: (B, P, D) prefix token embeddings.
            target_embeds: (B, T, D) embeddings of ``target_ids``.
            target_ids: (B, T) target token ids.
            target_attention_mask: optional (B, T) mask, 0 at padded positions.
        Returns:
            (inputs_embeds, attention_mask, labels) with sequence length P + T.
            Labels are -100 on the prefix and on masked target positions.
        """
        batch_size, num_prefix = prefix_embeds.shape[:2]
        device = target_ids.device

        if target_attention_mask is None:
            target_mask = torch.ones_like(target_ids)
            target_labels = target_ids
        else:
            target_mask = target_attention_mask.to(device)
            # Padded positions must not contribute to the LM loss.
            target_labels = target_ids.masked_fill(target_mask == 0, -100)

        prefix_mask = torch.ones(
            batch_size, num_prefix, dtype=target_mask.dtype, device=device
        )
        prefix_labels = torch.full(
            (batch_size, num_prefix), -100, dtype=target_labels.dtype, device=device
        )

        inputs_embeds = torch.cat([prefix_embeds, target_embeds], dim=1)
        attention_mask = torch.cat([prefix_mask, target_mask], dim=1)
        labels = torch.cat([prefix_labels, target_labels], dim=1)
        return inputs_embeds, attention_mask, labels

    def forward(
        self,
        patient_texts,
        dep_edges_json,
        target_ids=None,
        target_attention_mask=None,
        rag_context_ids=None,
        rag_context_mask=None,
    ):
        """
        Prefix-tuning forward pass:
            1. Encode patient text into a fused context sequence
            2. Compress the fused context into prefix tokens
            3. Embed the target (doctor response) tokens
            4. Prepend the prefix tokens to the target embeddings
            5. Run the decoder on the concatenated sequence

        Args:
            patient_texts: batch of patient query strings.
            dep_edges_json: per-sample dependency edges (see build_adjacency).
            target_ids: optional (B, T) decoder token ids; when omitted no
                generation loss is computed.
            target_attention_mask: optional (B, T) mask for ``target_ids``.
            rag_context_ids: accepted for interface compatibility; not used here.
            rag_context_mask: accepted for interface compatibility; not used here.
        Returns:
            dict with keys ``emotion_pred`` (logits), ``emotion_target``
            (frozen-model probabilities), ``gen_loss`` and ``logits`` (both
            None without targets) and ``emotion_loss`` (KL divergence).
        """
        device = next(self.encoder.parameters()).device

        fused_seq, emotion_probs = self.encode(patient_texts, dep_edges_json)
        prefix_embeds = self.context_compressor(fused_seq)

        # Auxiliary emotion prediction from the first fused position.
        # NOTE(review): that position already has the projected emotion
        # probabilities added in encode() — confirm this is intended.
        emotion_logits = self.emotion_classifier(fused_seq[:, 0, :])

        results = {
            "emotion_pred": emotion_logits,
            "emotion_target": emotion_probs,
            "gen_loss": None,
            "logits": None,
        }

        if target_ids is not None:
            target_ids = target_ids.to(device)
            target_embeds = self.decoder.get_input_embeddings()(target_ids)
            inputs_embeds, full_mask, labels = self._prefix_decoder_inputs(
                prefix_embeds, target_embeds, target_ids, target_attention_mask
            )
            decoder_out = self.decoder(
                inputs_embeds=inputs_embeds,
                attention_mask=full_mask,
                labels=labels,
            )
            results["gen_loss"] = decoder_out.loss
            results["logits"] = decoder_out.logits

        # KL(frozen emotion distribution || predicted distribution).
        emotion_log_probs = torch.log_softmax(emotion_logits, dim=-1)
        results["emotion_loss"] = nn.functional.kl_div(
            emotion_log_probs, emotion_probs, reduction="batchmean"
        )

        return results

    @torch.no_grad()
    def generate_with_context(
        self,
        patient_texts,
        dep_edges_json,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    ):
        """
        Full-pipeline generation for inference:
            1. Encode patient text into a fused context sequence
            2. Compress it into prefix tokens
            3. Append the start-token embedding to the prefix
            4. Autoregressively generate one response per sample

        Args:
            patient_texts: batch of patient query strings.
            dep_edges_json: per-sample dependency edges (see build_adjacency).
            max_new_tokens: generation length cap.
            temperature: sampling temperature (only used when do_sample).
            top_p: nucleus sampling threshold (only used when do_sample).
            do_sample: sample if True, otherwise decode deterministically.
        Returns:
            (generated_texts, emotion_pred): list of decoded strings and list
            of argmax emotion indices, one per sample.
        """
        device = next(self.encoder.parameters()).device

        fused_seq, _ = self.encode(patient_texts, dep_edges_json)
        prefix_embeds = self.context_compressor(fused_seq)

        emotion_logits = self.emotion_classifier(fused_seq[:, 0, :])
        emotion_pred = torch.argmax(emotion_logits, dim=-1)

        # NOTE(review): `or 2` replaces a missing id *and* an id of 0 with 2 —
        # kept as-is; confirm 2 is the intended start/pad id for the decoder.
        bos_id = self.decoder_tokenizer.bos_token_id or 2
        pad_id = self.decoder_tokenizer.eos_token_id or 2

        # Loop invariants: the start embedding and generation settings.
        bos_embed = self.decoder.get_input_embeddings()(
            torch.tensor([[bos_id]], device=device)
        )
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": 1.3,
            "pad_token_id": pad_id,
        }
        if do_sample:
            # Sampling-only knobs; passing them for greedy decoding only
            # triggers "unused generation flag" warnings.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        generated_texts = []
        for i in range(prefix_embeds.size(0)):
            start_embeds = torch.cat([prefix_embeds[i:i + 1], bos_embed], dim=1)
            generated_ids = self.decoder.generate(
                inputs_embeds=start_embeds,
                **gen_kwargs,
            )
            # With only `inputs_embeds` as the prompt, generate() returns just
            # the newly generated ids, so nothing is sliced off the front.
            text = self.decoder_tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True,
            )
            generated_texts.append(text.strip())

        return generated_texts, emotion_pred.cpu().tolist()
|
|