# unicosys-hypergraph / modeling_unicosys.py
# Source: Hugging Face repo "drzo/unicosys-hypergraph", commit ce3e598 (verified)
# fix: _tied_weights_keys as dict + post_init() for transformers 5.3.0
"""
Unicosys Hypergraph Knowledge Model
A trainable knowledge graph embedding model that encodes the unified
hypergraph (entities, evidence, transactions, communications) as
learned vector representations.
Load with:
from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
model = AutoModel.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
"""
import json
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_unicosys import UnicosysConfig
# ---------------------------------------------------------------------------
# Text Encoder (lightweight)
# ---------------------------------------------------------------------------
class LightweightTextEncoder(nn.Module):
    """Compact transformer encoder over tokenized node labels/descriptions.

    Embeds tokens with learned absolute positions, runs a small
    TransformerEncoder stack, mean-pools over real (non-pad) tokens, and
    projects the pooled vector into the model's shared hidden space.
    Returns a tensor of shape (batch, hidden_dim).
    """
    def __init__(self, config: UnicosysConfig):
        super().__init__()
        # Token + learned absolute position embeddings.
        self.token_embed = nn.Embedding(config.text_vocab_size, config.text_embed_dim)
        self.pos_embed = nn.Embedding(config.text_max_length, config.text_embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=config.text_embed_dim,
            nhead=config.text_num_heads,
            dim_feedforward=config.text_embed_dim * 4,
            dropout=config.gat_dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=config.text_num_layers)
        # Maps pooled text features into the model's hidden dimension.
        self.pool_proj = nn.Linear(config.text_embed_dim, config.hidden_dim)
    def forward(self, input_ids, attention_mask=None):
        batch, seq_len = input_ids.shape
        pos_ids = torch.arange(seq_len, device=input_ids.device)
        hidden = self.token_embed(input_ids) + self.pos_embed(
            pos_ids.unsqueeze(0).expand(batch, -1)
        )
        # Positions where the mask is 0 are padding and must be excluded
        # from self-attention.
        padding_mask = None if attention_mask is None else attention_mask == 0
        hidden = self.encoder(hidden, src_key_padding_mask=padding_mask)
        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Mean over real tokens only; clamp avoids division by zero
            # for a fully padded row.
            weights = attention_mask.unsqueeze(-1).float()
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        return self.pool_proj(pooled)
# ---------------------------------------------------------------------------
# Graph Attention Layer
# ---------------------------------------------------------------------------
class GraphAttentionLayer(nn.Module):
    """Multi-head graph attention for hypergraph node updates.

    Performs one sparse message-passing step: each directed edge
    (src -> tgt) sends a value vector to its target node, weighted by a
    softmax over all edges sharing that target. The softmax is computed
    with scatter ops, so no dense (N x N) attention matrix is formed.
    The aggregated message is added residually and LayerNorm-ed.
    """
    def __init__(self, config: UnicosysConfig):
        super().__init__()
        self.num_heads = config.gat_num_heads
        self.head_dim = config.hidden_dim // config.gat_num_heads
        # hidden_dim must divide evenly across attention heads.
        assert self.head_dim * self.num_heads == config.hidden_dim
        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        # Lifts per-edge type embeddings (node_embed_dim) to hidden_dim
        # so they can bias the keys in forward().
        self.edge_proj = nn.Linear(config.node_embed_dim, config.hidden_dim)
        self.out_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.gat_dropout)
    def forward(self, node_embeds, edge_index, edge_type_embeds):
        """Update node embeddings via one round of edge-wise attention.

        Args:
            node_embeds: (N, hidden_dim) current node states.
            edge_index: (2, E) long tensor of [source, target] indices.
            edge_type_embeds: (E, node_embed_dim) embedding of each
                edge's relation type.

        Returns:
            (N, hidden_dim) tensor: LayerNorm(node_embeds + aggregated
            messages). Nodes with no incoming edges receive a zero
            message and are effectively just re-normalized.
        """
        N = node_embeds.size(0)
        src, tgt = edge_index
        # Queries come from targets, keys/values from sources: each
        # target decides how much to attend to its incoming edges.
        q = self.q_proj(node_embeds[tgt])
        k = self.k_proj(node_embeds[src])
        v = self.v_proj(node_embeds[src])
        # Relation-aware attention: the edge type shifts the key.
        edge_bias = self.edge_proj(edge_type_embeds)
        k = k + edge_bias
        # Split into heads: (E, num_heads, head_dim).
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_heads, self.head_dim)
        v = v.view(-1, self.num_heads, self.head_dim)
        # Scaled dot-product score per edge and head: (E, num_heads).
        attn = (q * k).sum(dim=-1) / math.sqrt(self.head_dim)
        # Numerically-stabilized softmax grouped by target node.
        # NOTE(review): the zeros init participates in amax (include_self
        # defaults to True), so the subtracted max is max(score, 0); this
        # affects only numerical stability, not the normalized result,
        # since the shift cancels in the division below.
        attn_max = torch.zeros(N, self.num_heads, device=attn.device)
        attn_max.scatter_reduce_(0, tgt.unsqueeze(1).expand_as(attn), attn, reduce="amax")
        attn = torch.exp(attn - attn_max[tgt])
        attn_sum = torch.zeros(N, self.num_heads, device=attn.device)
        attn_sum.scatter_add_(0, tgt.unsqueeze(1).expand_as(attn), attn)
        # clamp guards against division by zero for targets whose
        # exponentiated scores underflowed.
        attn = attn / attn_sum[tgt].clamp(min=1e-8)
        attn = self.dropout(attn)
        # Weight the values and scatter-sum messages into each target.
        weighted = v * attn.unsqueeze(-1)
        weighted = weighted.view(-1, self.num_heads * self.head_dim)
        out = torch.zeros(N, self.num_heads * self.head_dim, device=weighted.device)
        out.scatter_add_(0, tgt.unsqueeze(1).expand_as(weighted), weighted)
        out = self.out_proj(out)
        # Residual connection + LayerNorm.
        return self.norm(node_embeds + out)
# ---------------------------------------------------------------------------
# Link Prediction Head
# ---------------------------------------------------------------------------
class LinkPredictionHead(nn.Module):
    """Scores candidate (src, tgt, edge_type) triples for link prediction.

    Concatenates the source and target node embeddings with a learned
    edge-type embedding and feeds them through a two-layer MLP that
    emits one scalar score per candidate edge.
    """
    def __init__(self, config: UnicosysConfig):
        super().__init__()
        self.edge_type_embed = nn.Embedding(config.num_edge_types, config.hidden_dim)
        # MLP: (3 * hidden_dim) -> hidden_dim -> 1 scalar score.
        self.scorer = nn.Sequential(
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.gat_dropout),
            nn.Linear(config.hidden_dim, 1),
        )
    def forward(self, src_embeds, tgt_embeds, edge_type_ids):
        """Return an (E,) tensor of scores for the given candidate edges."""
        type_vecs = self.edge_type_embed(edge_type_ids)
        features = torch.cat([src_embeds, tgt_embeds, type_vecs], dim=-1)
        scores = self.scorer(features)
        return scores.squeeze(-1)
# ---------------------------------------------------------------------------
# Main Model
# ---------------------------------------------------------------------------
class UnicosysHypergraphModel(PreTrainedModel):
    """
    Unicosys Hypergraph Knowledge Model.
    Encodes the unified hypergraph as trainable embeddings with:
    - Node type + subsystem structural embeddings
    - Text-based semantic embeddings from labels/descriptions
    - Graph attention for relational reasoning
    - Link prediction for discovering missing evidence connections
    Usage:
        from transformers import AutoConfig, AutoModel
        config = AutoConfig.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
        model = AutoModel.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
    """
    config_class = UnicosysConfig
    # No weights are tied in this model; the empty dict form matches the
    # transformers >= 5.3.0 _tied_weights_keys convention (see commit msg).
    _tied_weights_keys = {}
    supports_gradient_checkpointing = False
    def __init__(self, config: UnicosysConfig):
        super().__init__(config)
        # Structural embeddings
        self.node_type_embed = nn.Embedding(config.num_node_types, config.node_embed_dim)
        self.subsystem_embed = nn.Embedding(config.num_subsystems, config.node_embed_dim)
        self.node_id_embed = nn.Embedding(config.max_nodes, config.node_embed_dim)
        # Project structural features to hidden dim
        self.struct_proj = nn.Linear(config.node_embed_dim * 3, config.hidden_dim)
        # Text encoder for labels
        self.text_encoder = LightweightTextEncoder(config)
        # Combine structural + text
        self.combine_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
        self.combine_norm = nn.LayerNorm(config.hidden_dim)
        # Graph attention layers
        self.gat_layers = nn.ModuleList([
            GraphAttentionLayer(config) for _ in range(config.gat_num_layers)
        ])
        # Edge type embeddings for GAT
        self.edge_type_embed_gat = nn.Embedding(
            config.num_edge_types, config.node_embed_dim
        )
        # Link prediction head
        self.link_predictor = LinkPredictionHead(config)
        # Initialize weights
        # NOTE(review): post_init() below may also trigger transformers'
        # weight initialization, which would make this apply() redundant —
        # confirm against the pinned transformers version.
        self.apply(self._init_weights)
        # Required by transformers >= 5.x for tied weight tracking
        self.post_init()
    def _init_weights(self, module):
        # Xavier-uniform for Linear (zero bias), N(0, 0.02) for Embedding;
        # other module types are left at their PyTorch defaults.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def encode_nodes(
        self,
        node_ids: torch.LongTensor,
        node_type_ids: torch.LongTensor,
        subsystem_ids: torch.LongTensor,
        text_input_ids: Optional[torch.LongTensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode nodes into dense vectors of shape (N, hidden_dim).

        Concatenates the three structural embeddings (id, type, subsystem),
        projects them to hidden_dim, then fuses them with the text encoding
        of each node's label. When no text is provided, a zero vector
        stands in for the text half so the fusion projection still applies.
        """
        struct = torch.cat([
            self.node_id_embed(node_ids),
            self.node_type_embed(node_type_ids),
            self.subsystem_embed(subsystem_ids),
        ], dim=-1)
        struct = self.struct_proj(struct)
        if text_input_ids is not None:
            text = self.text_encoder(text_input_ids, text_attention_mask)
            combined = torch.cat([struct, text], dim=-1)
            return self.combine_norm(self.combine_proj(combined))
        else:
            # No label text: substitute zeros for the text component.
            zeros = torch.zeros_like(struct)
            combined = torch.cat([struct, zeros], dim=-1)
            return self.combine_norm(self.combine_proj(combined))
    def forward(
        self,
        node_ids: torch.LongTensor,
        node_type_ids: torch.LongTensor,
        subsystem_ids: torch.LongTensor,
        edge_index: torch.LongTensor,
        edge_type_ids: torch.LongTensor,
        text_input_ids: Optional[torch.LongTensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        pos_edge_index: Optional[torch.LongTensor] = None,
        pos_edge_types: Optional[torch.LongTensor] = None,
        neg_edge_index: Optional[torch.LongTensor] = None,
        neg_edge_types: Optional[torch.LongTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> dict:
        """
        Forward pass with optional link prediction training.

        Pipeline: encode_nodes -> stacked graph attention -> (optionally)
        score positive/negative candidate edges for link prediction.

        Returns dict with:
            - node_embeddings: (N, hidden_dim)
            - loss: scalar (if labels provided)
            - pos_scores: scores for positive edges
            - neg_scores: scores for negative edges

        NOTE(review): `labels` acts only as a train-mode flag here — the
        margin ranking loss never reads its values (the target is a tensor
        of ones, i.e. "positive should outrank negative"). Confirm callers
        don't expect per-edge label weighting.
        """
        # 1. Encode all nodes
        node_embeds = self.encode_nodes(
            node_ids, node_type_ids, subsystem_ids,
            text_input_ids, text_attention_mask,
        )
        # 2. Graph attention message passing
        edge_type_embeds = self.edge_type_embed_gat(edge_type_ids)
        for gat_layer in self.gat_layers:
            node_embeds = gat_layer(node_embeds, edge_index, edge_type_embeds)
        result = {"node_embeddings": node_embeds}
        # 3. Link prediction (if training edges provided)
        if pos_edge_index is not None and neg_edge_index is not None:
            pos_src, pos_tgt = pos_edge_index
            neg_src, neg_tgt = neg_edge_index
            pos_scores = self.link_predictor(
                node_embeds[pos_src], node_embeds[pos_tgt], pos_edge_types
            )
            neg_scores = self.link_predictor(
                node_embeds[neg_src], node_embeds[neg_tgt], neg_edge_types
            )
            result["pos_scores"] = pos_scores
            result["neg_scores"] = neg_scores
            if labels is not None:
                # Margin ranking: push pos_scores above neg_scores by at
                # least config.margin.
                loss = F.margin_ranking_loss(
                    pos_scores, neg_scores,
                    torch.ones_like(pos_scores),
                    margin=self.config.margin,
                )
                result["loss"] = loss
        return result
    def get_node_embedding(self, node_idx: int) -> torch.Tensor:
        """Get the embedding for a single node by index.

        Returns only the raw node-ID embedding row (node_embed_dim), not
        the full hidden-dim representation produced by encode_nodes/forward.
        """
        with torch.no_grad():
            return self.node_id_embed.weight[node_idx]