"""
Unicosys Hypergraph Knowledge Model.

A trainable knowledge-graph embedding model that encodes the unified
hypergraph (entities, evidence, transactions, communications) as learned
vector representations.

Load with::

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
    model = AutoModel.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
"""

import json
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_unicosys import UnicosysConfig


# ---------------------------------------------------------------------------
# Text Encoder (lightweight)
# ---------------------------------------------------------------------------

class LightweightTextEncoder(nn.Module):
    """A small transformer encoder for node labels and descriptions.

    Token embeddings plus learned positional embeddings are passed through a
    stack of ``nn.TransformerEncoderLayer`` blocks, mean-pooled over the
    non-padding positions, and projected to ``config.hidden_dim``.
    """

    def __init__(self, config: UnicosysConfig):
        super().__init__()
        self.token_embed = nn.Embedding(config.text_vocab_size, config.text_embed_dim)
        self.pos_embed = nn.Embedding(config.text_max_length, config.text_embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.text_embed_dim,
            nhead=config.text_num_heads,
            dim_feedforward=config.text_embed_dim * 4,
            dropout=config.gat_dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.text_num_layers
        )
        self.pool_proj = nn.Linear(config.text_embed_dim, config.hidden_dim)

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids into one pooled vector per sequence.

        Args:
            input_ids: ``(B, L)`` token ids. ``L`` must not exceed
                ``config.text_max_length`` because positions index ``pos_embed``.
            attention_mask: optional ``(B, L)`` mask; 0 marks padding, any
                other value marks a real token.

        Returns:
            ``(B, hidden_dim)`` tensor.
        """
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
        x = self.token_embed(input_ids) + self.pos_embed(positions)

        if attention_mask is None:
            x = self.encoder(x)
            return self.pool_proj(x.mean(dim=1))

        src_key_padding_mask = attention_mask == 0
        if L > 0:
            # A row that is entirely padding makes self-attention softmax over
            # only -inf scores, which yields NaN; NaN * 0 is still NaN, so the
            # masked pooling below would not hide it. Let such rows attend to
            # position 0 inside the encoder. Pooling still uses the original
            # mask, so these rows pool to a zero vector.
            fully_padded = src_key_padding_mask.all(dim=1)
            src_key_padding_mask[:, 0] &= ~fully_padded

        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)

        # Match x's dtype so half-precision models are not promoted to float32.
        mask = attention_mask.unsqueeze(-1).to(x.dtype)
        pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return self.pool_proj(pooled)


# ---------------------------------------------------------------------------
# Graph Attention Layer
# ---------------------------------------------------------------------------

class GraphAttentionLayer(nn.Module):
    """Multi-head graph attention for hypergraph node updates.

    For every edge ``src -> tgt`` the target node's query attends to the
    source node's key. The key is shifted by a projection of that edge's type
    embedding. Attention weights are softmax-normalised per target node and
    head over the node's incoming edges. The weighted source values are summed
    into the target, projected, added back as a residual, and layer-normalised.
    Nodes with no incoming edges receive an all-zero message.

    Raises:
        ValueError: if ``config.hidden_dim`` is not divisible by
            ``config.gat_num_heads``.
    """

    def __init__(self, config: UnicosysConfig):
        super().__init__()
        if config.hidden_dim % config.gat_num_heads != 0:
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"gat_num_heads ({config.gat_num_heads})"
            )
        self.num_heads = config.gat_num_heads
        self.head_dim = config.hidden_dim // config.gat_num_heads

        self.q_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.edge_proj = nn.Linear(config.node_embed_dim, config.hidden_dim)
        self.out_proj = nn.Linear(config.hidden_dim, config.hidden_dim)

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.gat_dropout)

    def forward(self, node_embeds, edge_index, edge_type_embeds):
        """Run one round of attention-weighted message passing.

        Args:
            node_embeds: ``(N, hidden_dim)`` node states.
            edge_index: ``(2, E)`` long tensor of ``[src, tgt]`` node indices.
            edge_type_embeds: ``(E, node_embed_dim)`` per-edge type embeddings.

        Returns:
            ``(N, hidden_dim)`` updated node states.
        """
        num_nodes = node_embeds.size(0)
        heads, head_dim = self.num_heads, self.head_dim
        src, tgt = edge_index

        q = self.q_proj(node_embeds[tgt]).view(-1, heads, head_dim)
        k = self.k_proj(node_embeds[src]) + self.edge_proj(edge_type_embeds)
        k = k.view(-1, heads, head_dim)
        v = self.v_proj(node_embeds[src]).view(-1, heads, head_dim)

        scores = (q * k).sum(dim=-1) / math.sqrt(head_dim)  # (E, heads)
        index = tgt.unsqueeze(1).expand_as(scores)

        # Per-target maximum for a numerically stable segment softmax.
        # include_self=False keeps the zero initial value out of the max, so a
        # node whose scores are all very negative does not underflow exp() to 0.
        # The max is detached: softmax is shift-invariant, so gradients match.
        # new_zeros/new_* keep the buffers in the scores' dtype and device,
        # which scatter ops require (fp16/bf16 support).
        score_max = scores.new_zeros(num_nodes, heads)
        score_max.scatter_reduce_(
            0, index, scores.detach(), reduce="amax", include_self=False
        )
        weights = torch.exp(scores - score_max[tgt])

        denom = weights.new_zeros(num_nodes, heads)
        denom.scatter_add_(0, index, weights)
        weights = weights / denom[tgt].clamp(min=1e-8)
        weights = self.dropout(weights)

        messages = (v * weights.unsqueeze(-1)).view(-1, heads * head_dim)
        out = messages.new_zeros(num_nodes, heads * head_dim)
        out.scatter_add_(0, tgt.unsqueeze(1).expand_as(messages), messages)

        return self.norm(node_embeds + self.out_proj(out))


# ---------------------------------------------------------------------------
# Link Prediction Head
# ---------------------------------------------------------------------------

class LinkPredictionHead(nn.Module):
    """Scores candidate edges for link prediction training.

    The source embedding, target embedding and a learned edge-type embedding
    are concatenated and fed to a two-layer MLP that outputs one score per edge.
    """

    def __init__(self, config: UnicosysConfig):
        super().__init__()
        self.edge_type_embed = nn.Embedding(config.num_edge_types, config.hidden_dim)
        self.scorer = nn.Sequential(
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.gat_dropout),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(self, src_embeds, tgt_embeds, edge_type_ids):
        """Return a ``(E,)`` tensor of raw (unnormalised) edge scores."""
        edge_embeds = self.edge_type_embed(edge_type_ids)
        combined = torch.cat([src_embeds, tgt_embeds, edge_embeds], dim=-1)
        return self.scorer(combined).squeeze(-1)


# ---------------------------------------------------------------------------
# Main Model
# ---------------------------------------------------------------------------

class UnicosysHypergraphModel(PreTrainedModel):
    """
    Unicosys Hypergraph Knowledge Model.

    Encodes the unified hypergraph as trainable embeddings with:
      - node ID + node type + subsystem structural embeddings
      - text-based semantic embeddings from labels/descriptions
      - graph attention for relational reasoning
      - link prediction for discovering missing evidence connections

    Usage:
        from transformers import AutoConfig, AutoModel

        config = AutoConfig.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
        model = AutoModel.from_pretrained("drzo/unicosys-hypergraph", trust_remote_code=True)
    """

    config_class = UnicosysConfig
    _tied_weights_keys = {}
    supports_gradient_checkpointing = False

    def __init__(self, config: UnicosysConfig):
        super().__init__(config)

        # Structural embeddings
        self.node_type_embed = nn.Embedding(config.num_node_types, config.node_embed_dim)
        self.subsystem_embed = nn.Embedding(config.num_subsystems, config.node_embed_dim)
        self.node_id_embed = nn.Embedding(config.max_nodes, config.node_embed_dim)

        # Project the three concatenated structural embeddings to hidden dim
        self.struct_proj = nn.Linear(config.node_embed_dim * 3, config.hidden_dim)

        # Text encoder for labels
        self.text_encoder = LightweightTextEncoder(config)

        # Combine structural + text
        self.combine_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
        self.combine_norm = nn.LayerNorm(config.hidden_dim)

        # Graph attention layers
        self.gat_layers = nn.ModuleList([
            GraphAttentionLayer(config) for _ in range(config.gat_num_layers)
        ])

        # Edge type embeddings consumed by the GAT layers
        self.edge_type_embed_gat = nn.Embedding(
            config.num_edge_types, config.node_embed_dim
        )

        # Link prediction head
        self.link_predictor = LinkPredictionHead(config)

        # Initialize weights.
        # NOTE(review): post_init() may also run _init_weights depending on the
        # transformers version; the explicit apply() is kept so fresh models
        # are initialised the same way regardless.
        self.apply(self._init_weights)

        # Required by transformers >= 5.x for tied weight tracking
        self.post_init()

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear layers get Xavier-uniform weights and zero biases. Embeddings
        get N(0, 0.02). LayerNorms are reset to weight=1 and bias=0, so modules
        transformers leaves uninitialised during loading also end up in a
        defined state.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode_nodes(
        self,
        node_ids: torch.LongTensor,
        node_type_ids: torch.LongTensor,
        subsystem_ids: torch.LongTensor,
        text_input_ids: Optional[torch.LongTensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode nodes into dense vectors of shape ``(N, hidden_dim)``.

        Structural embeddings (ID, type, subsystem) are concatenated and
        projected. When ``text_input_ids`` is given, the pooled text encoding is
        concatenated; otherwise a zero vector takes its place, so the same
        ``combine_proj`` applies in both cases.
        """
        struct = self.struct_proj(torch.cat([
            self.node_id_embed(node_ids),
            self.node_type_embed(node_type_ids),
            self.subsystem_embed(subsystem_ids),
        ], dim=-1))

        if text_input_ids is not None:
            text = self.text_encoder(text_input_ids, text_attention_mask)
        else:
            text = torch.zeros_like(struct)

        combined = torch.cat([struct, text], dim=-1)
        return self.combine_norm(self.combine_proj(combined))

    def forward(
        self,
        node_ids: torch.LongTensor,
        node_type_ids: torch.LongTensor,
        subsystem_ids: torch.LongTensor,
        edge_index: torch.LongTensor,
        edge_type_ids: torch.LongTensor,
        text_input_ids: Optional[torch.LongTensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        pos_edge_index: Optional[torch.LongTensor] = None,
        pos_edge_types: Optional[torch.LongTensor] = None,
        neg_edge_index: Optional[torch.LongTensor] = None,
        neg_edge_types: Optional[torch.LongTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
    ):
        """
        Forward pass with optional link prediction training.

        ``edge_index`` is a ``(2, E)`` tensor of ``[src, tgt]`` indices into the
        node tensors, with ``edge_type_ids`` giving one type per edge. Link
        prediction runs only when both ``pos_edge_index`` and ``neg_edge_index``
        are given; their ``*_edge_types`` must then be given as well.

        ``labels`` is only checked for presence. When it is not ``None``, a
        margin ranking loss (margin ``config.margin``) is computed that pushes
        each positive score above the paired negative score. Positive and
        negative edge sets must therefore have the same length.

        Returns a dict with:
            - node_embeddings: (N, hidden_dim)
            - pos_scores: scores for positive edges (if link prediction ran)
            - neg_scores: scores for negative edges (if link prediction ran)
            - loss: scalar (if link prediction ran and labels provided)

        Raises:
            ValueError: if positive/negative edges are given without their types.
        """
        # 1. Encode all nodes
        node_embeds = self.encode_nodes(
            node_ids, node_type_ids, subsystem_ids,
            text_input_ids, text_attention_mask,
        )

        # 2. Graph attention message passing
        edge_type_embeds = self.edge_type_embed_gat(edge_type_ids)
        for gat_layer in self.gat_layers:
            node_embeds = gat_layer(node_embeds, edge_index, edge_type_embeds)

        result = {"node_embeddings": node_embeds}

        # 3. Link prediction (if training edges provided)
        if pos_edge_index is not None and neg_edge_index is not None:
            if pos_edge_types is None or neg_edge_types is None:
                raise ValueError(
                    "pos_edge_types and neg_edge_types are required when "
                    "pos_edge_index and neg_edge_index are provided"
                )
            pos_src, pos_tgt = pos_edge_index
            neg_src, neg_tgt = neg_edge_index

            pos_scores = self.link_predictor(
                node_embeds[pos_src], node_embeds[pos_tgt], pos_edge_types
            )
            neg_scores = self.link_predictor(
                node_embeds[neg_src], node_embeds[neg_tgt], neg_edge_types
            )

            result["pos_scores"] = pos_scores
            result["neg_scores"] = neg_scores

            if labels is not None:
                # target = 1 means "first input should rank higher".
                result["loss"] = F.margin_ranking_loss(
                    pos_scores,
                    neg_scores,
                    torch.ones_like(pos_scores),
                    margin=self.config.margin,
                )

        return result

    def get_node_embedding(self, node_idx: int) -> torch.Tensor:
        """Return the learned ID-embedding row for ``node_idx``.

        This is the raw ``node_id_embed`` vector of size ``config.node_embed_dim``.
        It does not include type, subsystem or text features, nor graph
        attention; use ``forward(...)["node_embeddings"]`` for those. The
        result is a detached copy, so modifying it does not alter the weights.
        """
        return self.node_id_embed.weight[node_idx].detach().clone()