# Provenance (file-viewer residue): uploaded by shreenikethjoshi — "Create model.py", commit 10927d7 (verified).
"""
Phase 2: Core Model — EmotionAwareMedicalChatbot (v2 — Prefix-Tuning)

Architecture:
    Patient Query
      ├─→ Longformer Encoder ──→ context embeddings
      ├─→ ScispaCy Dep Graph → GCN ──→ syntax-aware features
      ├─→ Frozen Emotion Model ──→ emotion embedding (7-d)
      └─→ Cross-Attention Fusion ──→ fused context

    Fused context compressed into N prefix tokens
      → [PREFIX | Doctor tokens] fed to BioGPT decoder

Key Fix (v2): Uses prefix-tuning instead of encoder_hidden_states,
because BioGPT is a decoder-only model without cross-attention.
"""
import torch
import torch.nn as nn
import json
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForSequenceClassification,
AutoModelForCausalLM,
)
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
from config import (
LONGFORMER_MODEL,
EMOTION_MODEL,
GENERATOR_MODEL,
EMOTION_LABELS,
NUM_EMOTIONS,
MAX_INPUT_TOKENS,
MAX_TARGET_TOKENS,
NUM_PREFIX_TOKENS,
DEVICE,
)
# ============================================================
# GCN Layer (Lightweight, no DGL dependency at inference)
# ============================================================
class SimpleGCNLayer(nn.Module):
    """One graph-convolution step: X' = σ(D^{-1} A X W + b).

    Each node's features are replaced by the degree-normalised average of
    its neighbours' features, then passed through a linear map and GELU.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.GELU()

    def forward(self, node_features, adj_matrix):
        """
        Args:
            node_features: (B, N, in_dim) node embeddings.
            adj_matrix: (B, N, N) binary adjacency.

        Returns:
            (B, N, out_dim) updated node embeddings.
        """
        # Row degree, floored at 1 so isolated nodes do not divide by zero.
        row_degree = torch.clamp(
            adj_matrix.sum(dim=-1, keepdim=True), min=1
        )
        # Row-normalised adjacency averages each node's neighbourhood.
        neighbourhood_mean = torch.bmm(adj_matrix / row_degree, node_features)
        projected = self.linear(neighbourhood_mean)
        return self.activation(projected)
class SyntaxGCN(nn.Module):
    """Two stacked GCN layers that summarise a dependency graph as one vector."""

    def __init__(self, input_dim=768, hidden_dim=512, output_dim=256):
        super().__init__()
        self.gcn1 = SimpleGCNLayer(input_dim, hidden_dim)
        self.gcn2 = SimpleGCNLayer(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, node_features, adj_matrix):
        """
        Args:
            node_features: (B, N, input_dim) node embeddings.
            adj_matrix: (B, N, N) adjacency shared by both layers.

        Returns:
            (B, output_dim) graph-level embedding.
        """
        hidden = self.dropout(self.gcn1(node_features, adj_matrix))
        node_out = self.gcn2(hidden, adj_matrix)
        # Graph readout: unweighted average over every node position.
        return torch.mean(node_out, dim=1)
# ============================================================
# Cross-Attention Fusion
# ============================================================
class CrossAttentionFusion(nn.Module):
    """Injects the GCN syntax summary into the encoder sequence.

    The context tokens act as attention queries; the syntax vector is the
    single key/value token. The attention output is added residually to the
    context and layer-normalised.
    """

    def __init__(self, context_dim=768, syntax_dim=256, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=context_dim,
            num_heads=heads,
            kdim=syntax_dim,
            vdim=syntax_dim,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(context_dim)

    def forward(self, context_seq, syntax_vec):
        """
        Args:
            context_seq: (B, seq_len, context_dim) encoder hidden states.
            syntax_vec: (B, syntax_dim) graph-level syntax embedding.

        Returns:
            (B, seq_len, context_dim) fused sequence.
        """
        # The syntax summary becomes a length-1 key/value memory.
        memory = torch.unsqueeze(syntax_vec, 1)  # (B, 1, syntax_dim)
        attended = self.attn(query=context_seq, key=memory, value=memory)[0]
        residual = context_seq + attended
        return self.norm(residual)
# ============================================================
# Context Compressor (Prefix-Tuning)
# ============================================================
class ContextCompressor(nn.Module):
    """
    Turns a variable-length fused encoder sequence into a fixed number of
    'prefix tokens' living in the decoder's embedding space.

    The decoder has no cross-attention, so the patient context is injected
    by prepending these prefix embeddings to the decoder input embeddings.
    """

    def __init__(self, encoder_dim, decoder_dim, num_prefix_tokens=8):
        super().__init__()
        flat_dim = decoder_dim * num_prefix_tokens
        self.num_prefix = num_prefix_tokens
        # One projection emits all prefix embeddings at once (flattened).
        self.pool_proj = nn.Sequential(
            nn.Linear(encoder_dim, flat_dim),
            nn.GELU(),
            nn.LayerNorm(flat_dim),
        )
        self.decoder_dim = decoder_dim

    def forward(self, fused_seq):
        """
        Args:
            fused_seq: (B, S, encoder_dim) fused encoder sequence.

        Returns:
            (B, num_prefix, decoder_dim) prefix embeddings.
        """
        summary = torch.mean(fused_seq, dim=1)  # (B, encoder_dim)
        flat_prefix = self.pool_proj(summary)  # (B, num_prefix * decoder_dim)
        # Unflatten into one embedding per prefix slot.
        return flat_prefix.view(-1, self.num_prefix, self.decoder_dim)
# ============================================================
# Main Model
# ============================================================
class EmotionAwareMedicalChatbot(nn.Module):
    """
    Emotion-aware medical response generator (v2 — Prefix-Tuning).

    Pipeline:
      1. Longformer encoder → per-token context embeddings
      2. 2-layer syntax GCN over the dependency graph
      3. Frozen emotion classifier → soft emotion distribution
      4. Cross-attention fusion of context and syntax
      5. Context compressor → fixed number of prefix embeddings
      6. Causal-LM decoder conditioned on [PREFIX | response tokens]
    """

    def __init__(self):
        super().__init__()
        # NOTE: sub-modules are created in a fixed order so that parameter
        # registration (state_dict keys) and RNG consumption stay stable.

        # --- Longformer Encoder ---
        self.encoder_tokenizer = AutoTokenizer.from_pretrained(LONGFORMER_MODEL)
        self.encoder = AutoModel.from_pretrained(LONGFORMER_MODEL)

        # --- Syntax GCN ---
        self.syntax_gcn = SyntaxGCN(
            input_dim=self.encoder.config.hidden_size,
            hidden_dim=512,
            output_dim=256,
        )

        # --- Frozen Emotion Model ---
        self.emotion_tokenizer = AutoTokenizer.from_pretrained(EMOTION_MODEL)
        self.emotion_model = AutoModelForSequenceClassification.from_pretrained(
            EMOTION_MODEL
        )
        # Freeze completely: no gradients, and always in eval mode
        # (see train() below, which keeps it there).
        for param in self.emotion_model.parameters():
            param.requires_grad = False
        self.emotion_model.eval()

        # --- Cross-Attention Fusion ---
        self.cross_attn = CrossAttentionFusion(
            context_dim=self.encoder.config.hidden_size,
            syntax_dim=256,
        )

        # --- Emotion Projection ---
        self.emotion_proj = nn.Linear(NUM_EMOTIONS, self.encoder.config.hidden_size)

        # --- Generative Decoder (BioGPT) ---
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        self.decoder = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)

        # --- Context Compressor (Prefix-Tuning) ---
        encoder_dim = self.encoder.config.hidden_size
        decoder_dim = self.decoder.config.hidden_size
        self.context_compressor = ContextCompressor(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            num_prefix_tokens=NUM_PREFIX_TOKENS,
        )

        # --- Auxiliary Emotion Classifier (for multi-task loss) ---
        self.emotion_classifier = nn.Linear(encoder_dim, NUM_EMOTIONS)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen emotion model in eval.

        ``nn.Module.train`` recurses into every child, which would re-enable
        dropout inside the frozen emotion classifier and make the emotion
        targets stochastic. This override puts that one sub-module back in
        eval mode after the recursive switch.
        """
        super().train(mode)
        emotion_model = getattr(self, "emotion_model", None)
        if emotion_model is not None:
            emotion_model.eval()
        return self

    # ----------------------------------------------------------
    # Emotion extraction (frozen, no grad)
    # ----------------------------------------------------------
    @torch.no_grad()
    def get_emotion_embedding(self, texts):
        """Return the frozen classifier's soft emotion distribution.

        Args:
            texts: list of raw patient strings.

        Returns:
            (B, num_labels) tensor of softmax probabilities, computed
            without gradient tracking.
        """
        # Tokenised inputs must live where the emotion model's weights are.
        device = next(self.emotion_model.parameters()).device
        enc = self.emotion_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
        logits = self.emotion_model(**enc).logits
        return torch.softmax(logits, dim=-1)

    # ----------------------------------------------------------
    # Build adjacency matrix from dependency edges
    # ----------------------------------------------------------
    @staticmethod
    def build_adjacency(dep_edges_json, seq_len, device):
        """Build a batch of undirected, self-looped adjacency matrices.

        Args:
            dep_edges_json: one entry per batch item; each entry is either a
                JSON string or an already-parsed list of ``[head, child, rel]``
                edges. Unparseable entries yield a self-loop-only graph.
            seq_len: number of node positions (matrix side length).
            device: device on which to allocate the result.

        Returns:
            (B, seq_len, seq_len) float tensor with 1.0 on every edge (in
            both directions) and on the diagonal.
        """
        batch_size = len(dep_edges_json)
        adj = torch.zeros(batch_size, seq_len, seq_len, device=device)

        for b, raw_edges in enumerate(dep_edges_json):
            edges = raw_edges
            if isinstance(raw_edges, str):
                try:
                    edges = json.loads(raw_edges)
                except ValueError:  # json.JSONDecodeError is a ValueError
                    edges = []
            if not isinstance(edges, (list, tuple)):
                # None or a non-list JSON value: best effort, no edges.
                edges = []

            heads, children = [], []
            for edge in edges:
                try:
                    head, child = int(edge[0]), int(edge[1])
                except (TypeError, ValueError, IndexError, KeyError):
                    continue  # skip malformed edges, keep the rest
                # Reject negatives too: they would silently wrap around and
                # connect to positions at the end of the sequence.
                if 0 <= head < seq_len and 0 <= child < seq_len:
                    heads.append(head)
                    children.append(child)

            if heads:
                # One batched write per direction instead of one per edge.
                head_idx = torch.tensor(heads, device=device)
                child_idx = torch.tensor(children, device=device)
                adj[b, head_idx, child_idx] = 1.0
                adj[b, child_idx, head_idx] = 1.0  # undirected

        # Self-loops for every position of every graph, in one indexed write.
        diag = torch.arange(seq_len, device=device)
        adj[:, diag, diag] = 1.0
        return adj

    # ----------------------------------------------------------
    # Encode: Full pipeline (Longformer → GCN → Fusion → Emotion)
    # ----------------------------------------------------------
    def encode(self, patient_texts, dep_edges_json):
        """Run the encoder side of the model.

        Args:
            patient_texts: list of raw patient strings.
            dep_edges_json: per-item dependency edges (see build_adjacency).

        Returns:
            fused_seq: (B, S, H) fused context sequence, H = encoder hidden size.
            emotion_probs: (B, num_labels) frozen-classifier probabilities.
        """
        device = next(self.encoder.parameters()).device

        # 1. Encode patient dialogue with Longformer
        enc_inputs = self.encoder_tokenizer(
            patient_texts,
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_TOKENS,
            return_tensors="pt",
        ).to(device)
        encoder_out = self.encoder(**enc_inputs)
        context_seq = encoder_out.last_hidden_state  # (B, S, H)

        # 2. Build adjacency and run GCN
        # NOTE(review): padded positions also get self-loops and take part in
        # the downstream mean pools; confirm whether masking with
        # enc_inputs["attention_mask"] is wanted.
        seq_len = context_seq.size(1)
        adj = self.build_adjacency(dep_edges_json, seq_len, device)
        syntax_vec = self.syntax_gcn(context_seq, adj)  # (B, 256)

        # 3. Cross-attention fusion (context + syntax)
        fused_seq = self.cross_attn(context_seq, syntax_vec)  # (B, S, H)

        # 4. Emotion embedding
        emotion_probs = self.get_emotion_embedding(patient_texts).to(device)
        emotion_emb = self.emotion_proj(emotion_probs)  # (B, H)

        # Add emotion signal to the first (CLS) token position.
        # NOTE(review): forward() later predicts emotion_probs from this same
        # position, so the auxiliary target is visible in its own input.
        fused_seq[:, 0, :] = fused_seq[:, 0, :] + emotion_emb
        return fused_seq, emotion_probs

    @staticmethod
    def _build_prefixed_targets(target_ids, target_attention_mask, num_prefix):
        """Build decoder labels and attention mask for [PREFIX | target].

        Prefix positions and padded target positions get label -100 so the
        language-modelling loss ignores them.

        Args:
            target_ids: (B, T) long tensor of target token ids.
            target_attention_mask: (B, T) mask (1 = real token) or None, in
                which case every target position is treated as real.
            num_prefix: number of prefix positions prepended to the target.

        Returns:
            labels: (B, num_prefix + T) long tensor.
            full_mask: (B, num_prefix + T) long tensor.
        """
        device = target_ids.device
        batch_size = target_ids.size(0)

        if target_attention_mask is None:
            target_mask = torch.ones_like(target_ids)
        else:
            target_mask = target_attention_mask.to(device=device, dtype=torch.long)

        # Never train on padding: mask those labels out.
        target_labels = target_ids.masked_fill(target_mask == 0, -100)

        prefix_labels = torch.full(
            (batch_size, num_prefix), -100, dtype=torch.long, device=device
        )
        prefix_mask = torch.ones(
            batch_size, num_prefix, dtype=torch.long, device=device
        )
        labels = torch.cat([prefix_labels, target_labels], dim=1)
        full_mask = torch.cat([prefix_mask, target_mask], dim=1)
        return labels, full_mask

    # ----------------------------------------------------------
    # Forward Pass (Training — with teacher forcing)
    # ----------------------------------------------------------
    def forward(
        self,
        patient_texts,
        dep_edges_json,
        target_ids=None,
        target_attention_mask=None,
        rag_context_ids=None,
        rag_context_mask=None,
    ):
        """
        Prefix-Tuning forward pass:
          1. Encode patient text → fused context
          2. Compress fused context into N prefix tokens
          3. Get decoder input embeddings for the doctor response
          4. Prepend prefix tokens to decoder embeddings
          5. Run the decoder on the concatenated sequence

        Args:
            patient_texts: list of raw patient strings.
            dep_edges_json: per-item dependency edges (see build_adjacency).
            target_ids: optional (B, T) decoder token ids for teacher forcing.
            target_attention_mask: optional (B, T) mask for ``target_ids``.
            rag_context_ids, rag_context_mask: accepted for interface
                compatibility; not used by this method.

        Returns:
            dict with keys ``emotion_pred`` (logits), ``emotion_target``
            (frozen-classifier probabilities), ``gen_loss`` and ``logits``
            (both None when ``target_ids`` is None), and ``emotion_loss``
            (KL divergence between the two emotion distributions).
        """
        device = next(self.encoder.parameters()).device

        # === ENCODE ===
        fused_seq, emotion_probs = self.encode(patient_texts, dep_edges_json)

        # === COMPRESS → PREFIX TOKENS ===
        prefix_embeds = self.context_compressor(fused_seq)  # (B, N, decoder_dim)

        # === AUXILIARY EMOTION PREDICTION ===
        cls_vec = fused_seq[:, 0, :]
        emotion_logits = self.emotion_classifier(cls_vec)
        results = {"emotion_pred": emotion_logits, "emotion_target": emotion_probs}

        if target_ids is not None:
            target_ids = target_ids.to(device)
            # Decoder's own word embeddings for the target tokens.
            target_embeds = self.decoder.get_input_embeddings()(target_ids)
            # Prepend prefix: [PREFIX_1..PREFIX_N | target_1..target_T]
            inputs_embeds = torch.cat(
                [prefix_embeds.to(target_embeds.dtype), target_embeds], dim=1
            )
            # Size labels/mask from the actual prefix tensor so they cannot
            # drift from the compressor's configuration.
            labels, full_mask = self._build_prefixed_targets(
                target_ids, target_attention_mask, prefix_embeds.size(1)
            )
            decoder_out = self.decoder(
                inputs_embeds=inputs_embeds,
                attention_mask=full_mask,
                labels=labels,
            )
            results["gen_loss"] = decoder_out.loss
            results["logits"] = decoder_out.logits
        else:
            results["gen_loss"] = None
            results["logits"] = None

        # Emotion auxiliary loss (KL divergence)
        emotion_log_probs = torch.log_softmax(emotion_logits, dim=-1)
        results["emotion_loss"] = nn.functional.kl_div(
            emotion_log_probs, emotion_probs, reduction="batchmean"
        )
        return results

    # ----------------------------------------------------------
    # Generate (Inference — used by evaluate.py and app.py)
    # ----------------------------------------------------------
    @torch.no_grad()
    def generate_with_context(
        self,
        patient_texts,
        dep_edges_json,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    ):
        """
        Full-pipeline generation for inference:
          1. Encode patient → fused context
          2. Compress → prefix tokens
          3. Prepend prefix to the start token
          4. Autoregressively generate a response per batch item

        Args:
            patient_texts: list of raw patient strings.
            dep_edges_json: per-item dependency edges (see build_adjacency).
            max_new_tokens: generation length cap.
            temperature, top_p: sampling parameters, used when ``do_sample``.
            do_sample: sample if True, otherwise decode greedily.

        Returns:
            (generated_texts, emotion_pred): list of decoded strings and a
            list of predicted emotion class indices, one per batch item.
        """
        device = next(self.encoder.parameters()).device

        # === ENCODE ===
        fused_seq, _ = self.encode(patient_texts, dep_edges_json)

        # === COMPRESS → PREFIX ===
        prefix_embeds = self.context_compressor(fused_seq)  # (B, N, decoder_dim)

        # === EMOTION PREDICTION ===
        emotion_logits = self.emotion_classifier(fused_seq[:, 0, :])
        emotion_pred = torch.argmax(emotion_logits, dim=-1)

        # === GENERATE ===
        # NOTE(review): `or 2` also replaces a legitimate token id of 0.
        # Kept as-is because changing the start token would alter generation;
        # confirm against the decoder tokenizer's special-token ids.
        bos_id = self.decoder_tokenizer.bos_token_id or 2
        pad_id = self.decoder_tokenizer.eos_token_id or 2

        # Loop-invariant: the start-token embedding is the same for every item.
        bos_embed = self.decoder.get_input_embeddings()(
            torch.tensor([[bos_id]], device=device)
        )  # (1, 1, decoder_dim)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": 1.3,
            "pad_token_id": pad_id,
        }
        if do_sample:
            # Sampling-only knobs; irrelevant for greedy decoding.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        generated_texts = []
        for i in range(prefix_embeds.size(0)):
            # [PREFIX | BOS] → (1, N+1, decoder_dim)
            start_embeds = torch.cat(
                [prefix_embeds[i:i + 1].to(bos_embed.dtype), bos_embed], dim=1
            )
            start_mask = torch.ones(
                start_embeds.shape[:2], dtype=torch.long, device=device
            )
            generated_ids = self.decoder.generate(
                inputs_embeds=start_embeds,
                attention_mask=start_mask,
                **gen_kwargs,
            )
            # When generate() is driven by inputs_embeds without input_ids,
            # the returned ids contain only the newly generated tokens, so
            # there are no prefix positions to strip before decoding.
            text = self.decoder_tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True,
            )
            generated_texts.append(text.strip())

        return generated_texts, emotion_pred.cpu().tolist()