"""
Phase 2: Core Model — EmotionAwareMedicalChatbot (v2 — Prefix-Tuning)

Architecture:
    Patient Query
      ├─→ Longformer Encoder          ──→ context embeddings
      ├─→ ScispaCy Dep Graph → GCN    ──→ syntax-aware features
      ├─→ Frozen Emotion Model        ──→ emotion embedding (NUM_EMOTIONS-d)
      └─→ Cross-Attention Fusion      ──→ fused context

    The fused context is compressed into N prefix tokens
      → [PREFIX | Doctor tokens] is fed to the BioGPT decoder.

Key fix (v2): uses prefix-tuning instead of ``encoder_hidden_states``, because
BioGPT is a decoder-only model without cross-attention.
"""
import json
import logging
import os
import sys

import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Make the sibling ``config`` module importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from config import (  # noqa: E402  (must follow the sys.path tweak above)
    LONGFORMER_MODEL,
    EMOTION_MODEL,
    GENERATOR_MODEL,
    EMOTION_LABELS,
    NUM_EMOTIONS,
    MAX_INPUT_TOKENS,
    MAX_TARGET_TOKENS,
    NUM_PREFIX_TOKENS,
    DEVICE,
)

logger = logging.getLogger(__name__)


# ============================================================
# GCN Layer (Lightweight, no DGL dependency at inference)
# ============================================================
class SimpleGCNLayer(nn.Module):
    """Single graph-convolution layer: X' = GELU((D^{-1} A) X W + b).

    D is the row-degree matrix of A, so each node averages the features of its
    neighbours before the shared linear projection.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.GELU()

    def forward(self, node_features, adj_matrix):
        """
        Args:
            node_features: (B, N, in_dim) node feature matrix.
            adj_matrix: (B, N, N) adjacency matrix (binary weights expected).

        Returns:
            (B, N, out_dim) updated node features.
        """
        # Row-normalise; the clamp keeps isolated nodes (degree 0) from dividing by zero.
        degree = adj_matrix.sum(dim=-1, keepdim=True).clamp(min=1)
        adj_norm = adj_matrix / degree
        # Message passing: aggregate neighbour features with the normalised weights.
        agg = torch.bmm(adj_norm, node_features)  # (B, N, in_dim)
        return self.activation(self.linear(agg))


class SyntaxGCN(nn.Module):
    """2-layer GCN over a dependency graph, read out as one vector per example."""

    def __init__(self, input_dim=768, hidden_dim=512, output_dim=256):
        super().__init__()
        self.gcn1 = SimpleGCNLayer(input_dim, hidden_dim)
        self.gcn2 = SimpleGCNLayer(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, node_features, adj_matrix, node_mask=None):
        """
        Args:
            node_features: (B, N, input_dim) node features.
            adj_matrix: (B, N, N) adjacency matrix.
            node_mask: optional (B, N) mask, 1 for real nodes and 0 for padding.
                When given, padding nodes neither send messages nor contribute
                to the readout. When omitted, every node is used.

        Returns:
            (B, output_dim) graph-level embedding (mean over nodes).
        """
        mask = None
        if node_mask is not None:
            mask = node_mask.to(device=node_features.device, dtype=node_features.dtype)
            # Drop every edge touching a padding node (both rows and columns).
            adj_matrix = adj_matrix * mask.unsqueeze(1) * mask.unsqueeze(2)

        x = self.gcn1(node_features, adj_matrix)
        x = self.dropout(x)
        x = self.gcn2(x, adj_matrix)

        # Global graph readout: mean pool over (real) nodes.
        if mask is None:
            return x.mean(dim=1)  # (B, output_dim)
        weights = mask.unsqueeze(-1)  # (B, N, 1)
        # Clamp guards against a row that is entirely padding.
        return (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)


# ============================================================
# Cross-Attention Fusion
# ============================================================
class CrossAttentionFusion(nn.Module):
    """Fuse the GCN syntax vector into every Longformer position via cross-attention.

    Queries come from the context sequence; keys/values are the single syntax
    vector. With only one key the softmax weight is always 1, so every position
    receives the same projected syntax vector before the residual add and
    LayerNorm.
    """

    def __init__(self, context_dim=768, syntax_dim=256, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=context_dim,
            num_heads=heads,
            kdim=syntax_dim,
            vdim=syntax_dim,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(context_dim)

    def forward(self, context_seq, syntax_vec):
        """
        Args:
            context_seq: (B, seq_len, context_dim) from the Longformer encoder.
            syntax_vec: (B, syntax_dim) graph vector from the GCN.

        Returns:
            (B, seq_len, context_dim) fused sequence.
        """
        # Treat the syntax vector as a single key/value token.
        syntax_kv = syntax_vec.unsqueeze(1)  # (B, 1, syntax_dim)
        attn_out, _ = self.attn(context_seq, syntax_kv, syntax_kv)
        return self.norm(context_seq + attn_out)  # (B, seq_len, context_dim)


# ============================================================
# Context Compressor (Prefix-Tuning)
# ============================================================
class ContextCompressor(nn.Module):
    """
    Compress a variable-length fused encoder sequence into a fixed number of
    "prefix tokens" that are prepended to the decoder input.

    This is the key fix: BioGPT has no cross-attention, so the patient context
    is injected directly into its input-embedding space.
    """

    def __init__(self, encoder_dim, decoder_dim, num_prefix_tokens=8):
        super().__init__()
        self.num_prefix = num_prefix_tokens
        self.decoder_dim = decoder_dim
        # Pool + project the encoder sequence → N prefix embeddings.
        self.pool_proj = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * num_prefix_tokens),
            nn.GELU(),
            nn.LayerNorm(decoder_dim * num_prefix_tokens),
        )

    def forward(self, fused_seq, attention_mask=None):
        """
        Args:
            fused_seq: (B, S, encoder_dim) from the Longformer+GCN+Emotion fusion.
            attention_mask: optional (B, S) mask, 1 for real tokens and 0 for
                padding. When given, padding is excluded from the mean pool;
                when omitted, all S positions are averaged.

        Returns:
            prefix_embeds: (B, num_prefix, decoder_dim)
        """
        if attention_mask is None:
            pooled = fused_seq.mean(dim=1)  # (B, encoder_dim)
        else:
            weights = attention_mask.to(
                device=fused_seq.device, dtype=fused_seq.dtype
            ).unsqueeze(-1)  # (B, S, 1)
            pooled = (fused_seq * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        # Project to N * decoder_dim, then split into N prefix embeddings.
        projected = self.pool_proj(pooled)  # (B, N * decoder_dim)
        return projected.view(-1, self.num_prefix, self.decoder_dim)  # (B, N, decoder_dim)


# ============================================================
# Main Model
# ============================================================
class EmotionAwareMedicalChatbot(nn.Module):
    """
    Full architecture (v2 — Prefix-Tuning) combining:
        1. Clinical-Longformer encoder
        2. 2-layer Syntax GCN
        3. Frozen emotion classifier
        4. Cross-attention fusion
        5. Context Compressor → prefix tokens
        6. BioGPT generative decoder (prefix-conditioned)
    """

    def __init__(self):
        super().__init__()

        # --- Longformer Encoder ---
        self.encoder_tokenizer = AutoTokenizer.from_pretrained(LONGFORMER_MODEL)
        self.encoder = AutoModel.from_pretrained(LONGFORMER_MODEL)

        # --- Syntax GCN ---
        self.syntax_gcn = SyntaxGCN(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=512,
            output_dim=256,
        )

        # --- Frozen Emotion Model ---
        self.emotion_tokenizer = AutoTokenizer.from_pretrained(EMOTION_MODEL)
        self.emotion_model = AutoModelForSequenceClassification.from_pretrained(
            EMOTION_MODEL
        )
        # Freeze completely; ``train()`` below keeps it in eval mode as well.
        for param in self.emotion_model.parameters():
            param.requires_grad = False
        self.emotion_model.eval()

        # --- Cross-Attention Fusion ---
        self.cross_attn = CrossAttentionFusion(
            context_dim=self.encoder.config.hidden_size,
            syntax_dim=256,
        )

        # --- Emotion Projection ---
        self.emotion_proj = nn.Linear(NUM_EMOTIONS, self.encoder.config.hidden_size)

        # --- Generative Decoder (BioGPT) ---
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        self.decoder = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)

        # --- Context Compressor (Prefix-Tuning) ---
        encoder_dim = self.encoder.config.hidden_size
        decoder_dim = self.decoder.config.hidden_size
        self.context_compressor = ContextCompressor(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            num_prefix_tokens=NUM_PREFIX_TOKENS,
        )

        # --- Auxiliary Emotion Classifier (for multi-task loss) ---
        self.emotion_classifier = nn.Linear(encoder_dim, NUM_EMOTIONS)

    # ----------------------------------------------------------
    # Mode handling
    # ----------------------------------------------------------
    def train(self, mode=True):
        """Set training mode, keeping the frozen emotion model in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would otherwise
        re-enable dropout inside the frozen emotion classifier.
        """
        super().train(mode)
        self.emotion_model.eval()
        return self

    # ----------------------------------------------------------
    # Emotion extraction (frozen, no grad)
    # ----------------------------------------------------------
    @torch.no_grad()
    def get_emotion_embedding(self, texts):
        """Return a (B, NUM_EMOTIONS) softmax probability vector from the frozen model."""
        enc = self.emotion_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(next(self.encoder.parameters()).device)
        logits = self.emotion_model(**enc).logits
        return torch.softmax(logits, dim=-1)  # (B, NUM_EMOTIONS)

    # ----------------------------------------------------------
    # Build adjacency matrix from dependency edges
    # ----------------------------------------------------------
    @staticmethod
    def build_adjacency(dep_edges_json, seq_len, device):
        """Build a batch of symmetric adjacency matrices with self-loops.

        Args:
            dep_edges_json: one entry per example; each entry is a JSON string
                (or an already-decoded list) of ``[head, child, rel]`` triples.
                ``rel`` is ignored.
            seq_len: side length of each matrix; edges referring to positions
                outside ``[0, seq_len)`` are dropped.
            device: device for the returned tensor.

        Returns:
            adj: (B, seq_len, seq_len) float tensor with 1.0 on every edge (both
            directions) and on the diagonal.

        Malformed entries are logged and contribute only self-loops (plus any
        edges parsed before the error).
        """
        # NOTE(review): head/child are used directly as positions in the encoder
        # token sequence. If they are word-level parser indices they will not
        # line up with subword positions — confirm alignment upstream.
        batch_size = len(dep_edges_json)
        adj = torch.zeros(batch_size, seq_len, seq_len, device=device)
        for b, edges_str in enumerate(dep_edges_json):
            try:
                edges = json.loads(edges_str) if isinstance(edges_str, str) else edges_str
                for head, child, _ in edges:
                    head, child = int(head), int(child)
                    # Lower bound too: negative indices would silently wrap around.
                    if 0 <= head < seq_len and 0 <= child < seq_len:
                        adj[b, head, child] = 1.0
                        adj[b, child, head] = 1.0  # undirected
            except (json.JSONDecodeError, TypeError, ValueError) as err:
                logger.warning(
                    "Skipping malformed dependency edges for batch item %d: %s", b, err
                )

        # Self-loops for every example so each node keeps its own features.
        diag = torch.arange(seq_len, device=device)
        adj[:, diag, diag] = 1.0
        return adj

    # ----------------------------------------------------------
    # Encode: Full pipeline (Longformer → GCN → Fusion → Emotion)
    # ----------------------------------------------------------
    def _encode(self, patient_texts, dep_edges_json):
        """Run the encoder pipeline; like ``encode`` but also returns the token mask.

        Returns:
            fused_seq: (B, S, H) fused context sequence
            emotion_probs: (B, NUM_EMOTIONS) emotion probability vector
            attention_mask: (B, S) encoder tokenizer attention mask
        """
        device = next(self.encoder.parameters()).device

        # 1. Encode patient dialogue with Longformer
        enc_inputs = self.encoder_tokenizer(
            patient_texts,
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
            return_tensors="pt",
        ).to(device)
        attention_mask = enc_inputs["attention_mask"]
        encoder_out = self.encoder(**enc_inputs)
        context_seq = encoder_out.last_hidden_state  # (B, S, H)

        # 2. Build adjacency and run GCN (padding nodes masked out)
        seq_len = context_seq.size(1)
        adj = self.build_adjacency(dep_edges_json, seq_len, device)
        syntax_vec = self.syntax_gcn(context_seq, adj, node_mask=attention_mask)  # (B, 256)

        # 3. Cross-attention fusion (context + syntax)
        fused_seq = self.cross_attn(context_seq, syntax_vec)  # (B, S, H)

        # 4. Emotion embedding
        emotion_probs = self.get_emotion_embedding(patient_texts)  # (B, NUM_EMOTIONS)
        emotion_emb = self.emotion_proj(emotion_probs)  # (B, H)

        # Add the emotion signal to position 0 without writing in place into an
        # autograd output.
        fused_seq = torch.cat(
            [fused_seq[:, :1, :] + emotion_emb.unsqueeze(1), fused_seq[:, 1:, :]],
            dim=1,
        )
        return fused_seq, emotion_probs, attention_mask

    def encode(self, patient_texts, dep_edges_json):
        """
        Run the full encoder pipeline and return:
            fused_seq: (B, S, H) — fused context sequence
            emotion_probs: (B, NUM_EMOTIONS) — emotion probability vector
        """
        fused_seq, emotion_probs, _ = self._encode(patient_texts, dep_edges_json)
        return fused_seq, emotion_probs

    # ----------------------------------------------------------
    # Forward Pass (Training — with teacher forcing)
    # ----------------------------------------------------------
    def forward(
        self,
        patient_texts,
        dep_edges_json,
        target_ids=None,
        target_attention_mask=None,
        rag_context_ids=None,
        rag_context_mask=None,
    ):
        """
        Prefix-tuning forward pass:
            1. Encode patient text → fused context
            2. Compress fused context into N prefix tokens
            3. Embed the doctor-response tokens with the decoder's embeddings
            4. Prepend the prefix tokens to those embeddings
            5. Run BioGPT on the concatenated sequence

        ``rag_context_ids`` / ``rag_context_mask`` are accepted for interface
        compatibility but are not used here.

        Returns:
            dict with ``emotion_pred``, ``emotion_target``, ``gen_loss``,
            ``logits`` (the last two are None without ``target_ids``) and
            ``emotion_loss``. Target positions whose ``target_attention_mask``
            is 0 are excluded from ``gen_loss``.
        """
        device = next(self.encoder.parameters()).device

        # === ENCODE ===
        fused_seq, emotion_probs, enc_mask = self._encode(patient_texts, dep_edges_json)

        # === COMPRESS → PREFIX TOKENS ===
        prefix_embeds = self.context_compressor(fused_seq, attention_mask=enc_mask)
        num_prefix = prefix_embeds.size(1)

        # === AUXILIARY EMOTION PREDICTION ===
        # NOTE(review): cls_vec already contains emotion_proj(emotion_probs) (added
        # in _encode), and emotion_probs is also the KL target below, so this
        # auxiliary task can be solved from the injected signal — confirm intended.
        cls_vec = fused_seq[:, 0, :]  # (B, H)
        emotion_logits = self.emotion_classifier(cls_vec)  # (B, NUM_EMOTIONS)

        results = {"emotion_pred": emotion_logits, "emotion_target": emotion_probs}

        if target_ids is not None:
            target_ids = target_ids.to(device)
            batch_size = target_ids.size(0)

            # Decoder's own word embeddings for the target.
            target_embeds = self.decoder.get_input_embeddings()(target_ids)  # (B, T, D)

            # Prepend prefix: [PREFIX_1..PREFIX_N | target_1..target_T]
            inputs_embeds = torch.cat(
                [prefix_embeds.to(target_embeds.dtype), target_embeds], dim=1
            )  # (B, N+T, D)

            # Attention mask: prefix always visible; target mask as provided.
            prefix_mask = torch.ones(
                batch_size, num_prefix, dtype=torch.long, device=device
            )
            if target_attention_mask is not None:
                target_mask = target_attention_mask.to(device=device, dtype=torch.long)
                # Padding positions must not contribute to the LM loss.
                target_labels = target_ids.masked_fill(target_mask == 0, -100)
            else:
                target_mask = torch.ones_like(target_ids, dtype=torch.long)
                target_labels = target_ids
            full_mask = torch.cat([prefix_mask, target_mask], dim=1)  # (B, N+T)

            # Labels: -100 for prefix positions (no loss there).
            prefix_labels = torch.full(
                (batch_size, num_prefix), -100, dtype=torch.long, device=device
            )
            labels = torch.cat([prefix_labels, target_labels], dim=1)  # (B, N+T)

            # Run the decoder with the prepended context.
            decoder_out = self.decoder(
                inputs_embeds=inputs_embeds,
                attention_mask=full_mask,
                labels=labels,
            )
            results["gen_loss"] = decoder_out.loss
            results["logits"] = decoder_out.logits
        else:
            results["gen_loss"] = None
            results["logits"] = None

        # Emotion auxiliary loss: KL(emotion_probs || softmax(emotion_logits)).
        emotion_log_probs = torch.log_softmax(emotion_logits, dim=-1)
        results["emotion_loss"] = nn.functional.kl_div(
            emotion_log_probs, emotion_probs, reduction="batchmean"
        )

        return results

    # ----------------------------------------------------------
    # Generate (Inference — used by evaluate.py and app.py)
    # ----------------------------------------------------------
    @torch.no_grad()
    def generate_with_context(
        self,
        patient_texts,
        dep_edges_json,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    ):
        """
        Full-pipeline generation for inference:
            1. Encode patient → fused context
            2. Compress → prefix tokens
            3. Prepend the prefix to a start-token embedding
            4. Autoregressively generate a response (one example at a time)

        Returns:
            (generated_texts, emotion_pred): a list of decoded strings and a
            list of argmax emotion class indices, one per input.
        """
        device = next(self.encoder.parameters()).device

        # === ENCODE ===
        fused_seq, _, enc_mask = self._encode(patient_texts, dep_edges_json)

        # === COMPRESS → PREFIX ===
        prefix_embeds = self.context_compressor(fused_seq, attention_mask=enc_mask)

        # === EMOTION PREDICTION ===
        emotion_logits = self.emotion_classifier(fused_seq[:, 0, :])
        emotion_pred = torch.argmax(emotion_logits, dim=-1)

        # === GENERATE ===
        # `or` falls back to 2 when the tokenizer has no BOS id — and also when
        # the BOS id is 0. NOTE(review): confirm this start token matches the
        # first token of the training targets.
        bos_id = self.decoder_tokenizer.bos_token_id or 2
        pad_id = self.decoder_tokenizer.eos_token_id or 2
        bos_embed = self.decoder.get_input_embeddings()(
            torch.tensor([[bos_id]], device=device)
        )  # (1, 1, D)

        generated_texts = []
        for i in range(prefix_embeds.size(0)):
            # [PREFIX | BOS]
            start_embeds = torch.cat(
                [prefix_embeds[i:i + 1].to(bos_embed.dtype), bos_embed], dim=1
            )  # (1, N+1, D)
            start_mask = torch.ones(
                start_embeds.shape[:2], dtype=torch.long, device=device
            )

            out = self.decoder.generate(
                inputs_embeds=start_embeds,
                attention_mask=start_mask,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.3,
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                output_scores=True,
            )
            # One score entry per generated step, so the last len(scores) ids
            # are exactly the new tokens, whether or not `sequences` also
            # echoes prompt positions.
            num_new = len(out.scores)
            new_ids = out.sequences[0, out.sequences.size(1) - num_new:]
            text = self.decoder_tokenizer.decode(new_ids, skip_special_tokens=True)
            generated_texts.append(text.strip())

        return generated_texts, emotion_pred.cpu().tolist()