# Hugging Face Hub page header (scraped): tags PyTorch, cisa_berturk, custom_code
# File: CISA-BERTurk-sentiment / modeling_cisa.py
# Last commit: f4712b8 "Update modeling_cisa.py" by ilterm (verified)
# coding: utf-8
"""
CISA-BERTurk-Sentiment: Cross-Individual Sentiment Analysis for Historical Turkish
DECA-EBSA (Dual-Encoder Context-Aware Entity-Based Sentiment Analysis) Architecture
İzmir Institute of Technology - Digital Humanities and AI Laboratory
TÜBİTAK Project No: 323K372
Usage:
from modeling_cisa import CISAModel
# from_pretrained() loads the matching tokenizer itself.
model = CISAModel.from_pretrained("dbbiyte/CISA-BERTurk-sentiment")
result = model.predict(
text="Ali Bey'in vefatı bizleri elem-i azîme sevk etmişti.",
entity_text="Ali Bey",
entity_start=0,
entity_end=7,
)
print(result) # {'sentiment': 2, 'sentiment_label': 'Positive', ...}
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from typing import Optional, List
# Class-index -> label name for the two classifier heads
# (sentiment head: 3 classes, relation head: 2 classes).
SENTIMENT_LABELS = dict(enumerate(("Negative", "Neutral", "Positive")))
RELATION_LABELS = dict(enumerate(("Indirect", "Direct")))
# ─────────────────────────────────────────────────────────────
# Sub-modules (training koduyla birebir aynı)
# ─────────────────────────────────────────────────────────────
class AdaptiveFocalLoss(nn.Module):
    """Focal loss with an optional extra weight that grows for hard examples.

    Per sample, with ``ce`` the unreduced cross-entropy and ``pt = exp(-ce)``
    (the predicted probability of the target class)::

        loss = alpha_t * (1 - pt) ** gamma * [1 + exp(-2 * pt)] * ce

    The bracketed factor is applied only when ``difficulty_weight`` is true.

    Args:
        alpha: ``None`` (no weighting), a scalar applied to every sample, or a
            1-D tensor indexed by each sample's target class.
        gamma: Focusing exponent applied to ``1 - pt``.
        size_average: Return the mean over samples if true, else the sum.
        difficulty_weight: Enable the ``1 + exp(-2 * pt)`` factor.
    """

    def __init__(self, alpha=None, gamma=2.0, size_average=True, difficulty_weight=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.size_average = size_average
        self.difficulty_weight = difficulty_weight

    def forward(self, inputs, targets):
        """Compute the reduced loss for logits ``inputs`` and class ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p_true = torch.exp(-ce_loss)

        if self.alpha is None:
            alpha_t = 1.0
        elif isinstance(self.alpha, (float, int)):
            alpha_t = self.alpha
        else:
            # Per-class weights: pick each sample's weight by its target index.
            alpha_t = self.alpha.gather(0, targets.data.view(-1))

        modulating = (1 - p_true) ** self.gamma
        if self.difficulty_weight:
            modulating = modulating * (1 + torch.exp(-p_true * 2))

        per_sample = alpha_t * modulating * ce_loss
        if self.size_average:
            return per_sample.mean()
        return per_sample.sum()
class TurkishLinguisticFeatures(nn.Module):
    """Derive a 64-dim feature vector from a pooled text representation.

    Each input vector is treated as a length-1 sequence for an 8-head
    self-attention pass, refined with a residual linear projection, then fed
    to three small MLP heads whose outputs (64 + 32 + 48 = 144 dims) are
    concatenated and fused down to 64 dims.

    Submodule attribute names form the ``state_dict`` keys and are kept as-is.
    """

    def __init__(self, hidden_size):
        super().__init__()
        half, quarter = hidden_size // 2, hidden_size // 4
        self.adjective_noun_attention = nn.MultiheadAttention(
            hidden_size, num_heads=8, dropout=0.1, batch_first=True)
        self.historical_word_projection = nn.Linear(hidden_size, hidden_size)
        self.respect_pattern_detector = self._head(hidden_size, half, 64)
        self.formality_detector = self._head(hidden_size, quarter, 32)
        self.morphological_analyzer = self._head(hidden_size, half, 48)
        # 144 = 64 + 32 + 48, the concatenated head widths.
        self.linguistic_fusion = nn.Sequential(
            nn.Linear(144, 128), nn.LayerNorm(128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, 64))

    @staticmethod
    def _head(in_dim, mid_dim, out_dim):
        """Two-layer GELU MLP with dropout: in_dim -> mid_dim -> out_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(mid_dim, out_dim))

    def forward(self, text_repr, entity_repr):
        """Return fused linguistic features for ``text_repr``.

        ``text_repr`` is presumably (batch, hidden) — TODO confirm with callers.
        ``entity_repr`` is accepted for interface parity but is not used.
        """
        tokens = text_repr.unsqueeze(1)
        attended, _ = self.adjective_noun_attention(tokens, tokens, tokens)
        attended = attended.squeeze(1)
        mixed = attended + self.historical_word_projection(attended)
        head_outputs = [
            head(mixed)
            for head in (
                self.respect_pattern_detector,
                self.formality_detector,
                self.morphological_analyzer,
            )
        ]
        return self.linguistic_fusion(torch.cat(head_outputs, dim=-1))
class EnhancedEntityContextAttention(nn.Module):
    """Attend from an entity vector over the text sequence at three scopes.

    The output mixes three views of the sequence:
      1. global multi-head attention from the entity query over all
         non-padded tokens;
      2. the same global attention weights re-scaled by a fixed positional
         prior around the entity span (see ``_pos_weights``), used as a
         weighted sum over the sequence;
      3. multi-head attention restricted to a small window around the span.
    A learned gate (softmax over 3 logits) combines them, followed by
    LayerNorm.

    Submodule attribute names form the ``state_dict`` keys and are kept as-is.
    """
    def __init__(self, hidden_size, num_heads=12, dropout=0.1):
        super().__init__()
        # Global entity -> sequence attention.
        self.entity_context_attention = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True)
        # Learned absolute positions; sequences longer than 512 tokens cannot be indexed.
        self.position_embedding = nn.Embedding(512, hidden_size)
        # Attention over a local window around the entity span (8 heads).
        self.local_context_attention = nn.MultiheadAttention(
            hidden_size, 8, dropout=dropout, batch_first=True)
        # Gate: maps the three concatenated views to 3 mixing logits.
        self.hierarchical_attention = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size // 2), nn.Tanh(),
            nn.Linear(hidden_size // 2, 3))
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)

    def _pos_weights(self, entity_positions, seq_len, device):
        """Build a (len(entity_positions), seq_len) positional prior.

        For each (s, e) pair (e treated as inclusive via ``s:e+1``): 3.0 on
        the span, 2.0 on up to 3 tokens on either side, 0.5 elsewhere.
        Writes are applied in this order, so later writes win where slices
        overlap (e.g. if s > e).
        """
        B = len(entity_positions)
        W = torch.ones(B, seq_len, device=device)
        for i, (s, e) in enumerate(entity_positions):
            W[i, s:e+1] = 3.0
            # Neighbourhood bounds: 3 tokens before s, 3 tokens after e (clipped).
            cs, ce = max(0, s-3), min(seq_len, e+4)
            W[i, cs:s] = 2.0; W[i, e+1:ce] = 2.0
            W[i, :cs] = 0.5; W[i, ce:] = 0.5
        return W

    def forward(self, entity_repr, text_sequence, entity_positions, attention_mask):
        """Return (entity-conditioned context vector, global attention weights).

        Args:
            entity_repr: Entity query vectors; presumably (B, H) — TODO confirm.
            text_sequence: Token states of shape (B, L, H).
            entity_positions: One (start, end) token-index pair per example.
            attention_mask: (B, L) mask where nonzero marks real tokens; its
                negation is passed as ``key_padding_mask`` (True = ignore).

        Returns:
            ``layer_norm2`` of the gated mix, and the global attention weights
            ``g_w`` as returned by ``nn.MultiheadAttention`` (head-averaged by
            default, i.e. (B, 1, L)).
        """
        B, L, H = text_sequence.shape
        dev = text_sequence.device
        # Add learned absolute position embeddings, then normalise.
        pos_ids = torch.arange(L, device=dev).unsqueeze(0).expand(B, -1)
        enhanced = self.layer_norm1(text_sequence + self.position_embedding(pos_ids))
        pw = self._pos_weights(entity_positions, L, dev)
        eq = entity_repr.unsqueeze(1)  # one query per example
        # View 1: global attention over non-padded tokens.
        g_att, g_w = self.entity_context_attention(
            eq, enhanced, enhanced, key_padding_mask=~attention_mask.bool())
        g_att = g_att.squeeze(1)
        # View 2: attention weights times the positional prior, used as-is
        # (not renormalised) for a weighted sum over the sequence.
        w_att = torch.bmm((g_w.squeeze(1) * pw).unsqueeze(1), enhanced).squeeze(1)
        # View 3: per-example attention over the window [s-5, e+5] (clipped).
        # No padding mask is applied inside this window.
        local = []
        for i, (s, e) in enumerate(entity_positions):
            lc = enhanced[i, max(0,s-5):min(L,e+6)].unsqueeze(0)
            if lc.size(1) > 0:
                la, _ = self.local_context_attention(eq[i:i+1], lc, lc)
                local.append(la.squeeze(1))
            else:
                # Empty window: fall back to the raw entity vector.
                local.append(entity_repr[i:i+1])
        local = torch.cat(local, 0)
        # Gate: softmax over the three views, then a convex combination.
        hw = F.softmax(self.hierarchical_attention(
            torch.cat([g_att, w_att, local], -1)).view(-1, 3), -1)
        final = hw[:,0:1]*g_att + hw[:,1:2]*w_att + hw[:,2:3]*local
        return self.layer_norm2(final), g_w
class ContextualSentimentEncoder(nn.Module):
    """BiLSTM context encoder with attention pooling and a 2-way type head.

    The sequence is run through a 2-layer bidirectional LSTM (output width
    ``2 * (hidden_size // 2)``). A pooled vector is formed from the LSTM
    states and used (a) as the query of an 8-head attention pooling over the
    states, and (b) as the input of a 2-class context-type classifier.

    Submodule attribute names form the ``state_dict`` keys and are kept as-is.
    """

    def __init__(self, hidden_size):
        super().__init__()
        half = hidden_size // 2
        self.context_lstm = nn.LSTM(
            hidden_size, half, num_layers=2,
            bidirectional=True, dropout=0.1, batch_first=True)
        self.sentiment_pooling = nn.MultiheadAttention(hidden_size, 8, batch_first=True)
        self.context_type_classifier = nn.Sequential(
            nn.Linear(hidden_size, half), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(half, 2))

    def forward(self, context_sequence, entity_position_mask):
        """Return (attention-pooled context vector, context-type logits).

        ``entity_position_mask`` may be ``None`` (plain mean over positions)
        or a per-token mask broadcastable after ``unsqueeze(-1)``.

        NOTE: with a mask, positions outside it are zeroed but the mean still
        divides by the full sequence length, so this is a scaled masked sum,
        not a masked mean. Kept as-is because trained weights depend on it.
        """
        lstm_out, _ = self.context_lstm(context_sequence)
        if entity_position_mask is None:
            weighted = lstm_out
        else:
            weighted = lstm_out * entity_position_mask.unsqueeze(-1).float()
        pooled = weighted.mean(1)
        query = pooled.unsqueeze(1)
        attended, _ = self.sentiment_pooling(query, lstm_out, lstm_out)
        return attended.squeeze(1), self.context_type_classifier(pooled)
# ─────────────────────────────────────────────────────────────
# Core model (nn.Module — training koduyla birebir)
# ─────────────────────────────────────────────────────────────
class PositionAwareDualEncoderEBSA(nn.Module):
    """Dual-encoder, entity-based sentiment/relation classifier (DECA-EBSA core).

    Architecture identical to the training code. Two BERT encoders, both
    initialised from ``model_name``, embed the passage (``text_encoder``)
    and an entity-centred input (``entity_encoder``). Their outputs are
    combined via entity-context attention, a linguistic-feature branch and a
    BiLSTM context branch, then fed to two heads:

    * ``sentiment_classifier`` -> ``num_sentiment_labels`` logits
    * ``relation_classifier``  -> 2 logits

    ``position_embedding``, ``entity_position_proj`` and ``layer_norm3`` are
    not used by ``forward`` but must remain so checkpoint ``state_dict`` keys
    match. ``use_r_drop``, ``clip_grad_norm``, ``stochastic_depth_rate``,
    ``layer_drop_probs`` and ``label_smoothing`` are stored but not read by
    this class.
    """

    def __init__(self, model_name='dbmdz/bert-base-turkish-cased',
                 num_sentiment_labels=3, dropout_rate=0.1,
                 use_r_drop=True, stochastic_depth_rate=0.1):
        super().__init__()
        # Stored for parity with training; not read by this class.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.text_encoder = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
        self.entity_encoder = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
        self.config = self.text_encoder.config
        self.dropout_rate = dropout_rate
        self.use_r_drop = use_r_drop
        self.clip_grad_norm = 1.0
        self.stochastic_depth_rate = stochastic_depth_rate
        H = self.text_encoder.config.hidden_size  # 768 for BERT-base
        self.enhanced_attention = EnhancedEntityContextAttention(H, 12, dropout_rate)
        self.turkish_linguistic = TurkishLinguisticFeatures(H)
        self.contextual_encoder = ContextualSentimentEncoder(H)
        self.position_embedding = nn.Embedding(512, H)
        self.entity_position_proj = nn.Linear(H, H)
        self.layer_norm1 = nn.LayerNorm(H)
        self.layer_norm2 = nn.LayerNorm(H)
        # Unused in forward; exists in the training code, kept for state_dict compatibility.
        self.layer_norm3 = nn.LayerNorm(H * 2)
        self.dropout = nn.Dropout(dropout_rate)
        # Input: text CLS + entity CLS + cross-attention (3 * H) + 64 linguistic features.
        self.enhanced_fusion = nn.Sequential(
            nn.Linear(H*3+64, H*2), nn.LayerNorm(H*2), nn.GELU(), nn.Dropout(dropout_rate),
            nn.Linear(H*2, H), nn.LayerNorm(H))
        # Heads take the fused vector (H) plus the 2 context-type probabilities.
        self.sentiment_classifier = nn.Sequential(
            nn.Linear(H+2, H//2), nn.LayerNorm(H//2), nn.GELU(), nn.Dropout(dropout_rate),
            nn.Linear(H//2, num_sentiment_labels))
        self.relation_classifier = nn.Sequential(
            nn.Linear(H+2, H//2), nn.LayerNorm(H//2), nn.GELU(), nn.Dropout(dropout_rate),
            nn.Linear(H//2, 2))
        self.label_smoothing = 0.1
        num_layers = len(self.text_encoder.encoder.layer)
        self.layer_drop_probs = [stochastic_depth_rate * i / num_layers for i in range(num_layers)]

    def _weighted_layers(self, encoder, input_ids, attn_mask):
        """Return a fixed weighted sum of the encoder's last four hidden layers.

        The layers are weighted 0.1 / 0.2 / 0.3 / 0.4, with the deepest layer
        weighted most. The encoder's ``last_hidden_state`` is returned as the
        second element.
        """
        out = encoder(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True)
        last4 = torch.stack(out.hidden_states[-4:])
        # Build the weights in the hidden states' dtype/device. A float32 weight
        # tensor would promote fp16/bf16 activations to float32 and then clash
        # with half-precision layers downstream.
        w = torch.tensor([0.1, 0.2, 0.3, 0.4],
                         device=last4.device, dtype=last4.dtype).view(4, 1, 1, 1)
        return (last4 * w).sum(0), out.last_hidden_state

    def _entity_repr(self, text_out, entity_positions, position_mask):
        """Mean-pool token states selected by each row's position mask.

        If a row's mask selects no token, token 0 is used instead. This
        method is not called by ``forward``, and ``entity_positions`` is not
        read.
        """
        reprs = []
        for i in range(text_out.size(0)):
            # as_tensor accepts both list rows and tensor rows without copy warnings.
            pm = torch.as_tensor(position_mask[i], device=text_out.device).bool()
            toks = text_out[i][pm]
            reprs.append(toks.mean(0) if toks.size(0) > 0 else text_out[i, 0])
        return torch.stack(reprs)

    def forward(self, text_input_ids, text_attention_mask,
                entity_input_ids, entity_attention_mask,
                entity_positions, position_mask,
                sentiment_label=None, relation_label=None, **kwargs):
        """Run both encoders and heads.

        Args:
            text_input_ids, text_attention_mask: Passage tokens, (B, L).
            entity_input_ids, entity_attention_mask: Entity-side tokens.
            entity_positions: One (start, end) token-index pair per example
                (used by ``EnhancedEntityContextAttention``).
            position_mask: Per-example 0/1 token masks over the passage
                (lists or tensors), each of length L.
            sentiment_label, relation_label: Optional targets. The loss is
                computed only when BOTH are given.
            **kwargs: Ignored.

        Returns:
            Dict with ``loss`` (or ``None``), ``sentiment_logits``,
            ``relation_logits``, ``attention_weights`` and
            ``context_type_logits``.
        """
        dev = text_input_ids.device
        text_out, _ = self._weighted_layers(self.text_encoder, text_input_ids, text_attention_mask)
        entity_out, _ = self._weighted_layers(self.entity_encoder, entity_input_ids, entity_attention_mask)
        # Position-0 ([CLS]) summaries of each encoder.
        text_cls = text_out[:, 0, :]
        entity_cls = entity_out[:, 0, :]
        cross_att, att_w = self.enhanced_attention(
            entity_cls, text_out, entity_positions, text_attention_mask)
        turkish_feat = self.turkish_linguistic(text_cls, entity_cls)
        pm_tensor = torch.stack([torch.as_tensor(m, device=dev) for m in position_mask])
        # ctx_repr is not used below; only the context-type logits feed the heads.
        ctx_repr, ctx_type_logits = self.contextual_encoder(text_out, pm_tensor)
        text_cls = F.normalize(text_cls, p=2, dim=1)
        entity_cls = F.normalize(entity_cls, p=2, dim=1)
        cross_att = F.normalize(cross_att, p=2, dim=1)
        entity_cls = self.layer_norm1(entity_cls)
        cross_att = self.layer_norm2(cross_att + entity_cls)
        fused = self.enhanced_fusion(
            torch.cat([text_cls, entity_cls, cross_att, turkish_feat], 1))
        ctx_probs = F.softmax(ctx_type_logits, -1)
        clf_input = torch.cat([fused, ctx_probs], 1)
        sentiment_logits = self.sentiment_classifier(clf_input)
        relation_logits = self.relation_classifier(clf_input)
        loss = None
        if sentiment_label is not None and relation_label is not None:
            loss_fn = AdaptiveFocalLoss(alpha=0.25, gamma=2.0, difficulty_weight=True)
            loss = loss_fn(sentiment_logits, sentiment_label) + \
                loss_fn(relation_logits, relation_label)
        return {
            'loss': loss,
            'sentiment_logits': sentiment_logits,
            'relation_logits': relation_logits,
            'attention_weights': att_w,
            'context_type_logits': ctx_type_logits,
        }
# ─────────────────────────────────────────────────────────────
# Public wrapper — kullanıcının doğrudan kullandığı sınıf
# ─────────────────────────────────────────────────────────────
class CISAModel:
    """User-facing inference wrapper around ``PositionAwareDualEncoderEBSA``.

    Requires neither ``AutoModel`` nor ``trust_remote_code``::

        from modeling_cisa import CISAModel
        model = CISAModel.from_pretrained("dbbiyte/CISA-BERTurk-sentiment")
        result = model.predict("Ali Bey'in vefatı...", "Ali Bey", 0, 7)
    """
    SENTIMENT_LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}
    RELATION_LABELS = {0: "Indirect", 1: "Direct"}

    def __init__(self, model: "PositionAwareDualEncoderEBSA", tokenizer):
        """Wrap an already-built core model and its tokenizer.

        Args:
            model: The core network. Inputs are moved to the device of its
                first parameter.
            tokenizer: Tokenizer used by ``predict``. It must support
                ``return_offsets_mapping`` (e.g. a Hugging Face "fast" tokenizer).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    # ── factory ──────────────────────────────────────────────
    @classmethod
    def from_pretrained(cls, repo_id: str = "dbbiyte/CISA-BERTurk-sentiment",
                        device: Optional[torch.device] = None,
                        tokenizer=None):
        """Build the model and tokenizer from a Hugging Face Hub repo.

        ``pytorch_model.bin`` is loaded directly as a ``state_dict`` into a
        freshly built ``PositionAwareDualEncoderEBSA``. Loading is non-strict:
        missing and unexpected keys are printed, not raised.

        Args:
            repo_id: Hub repository holding the tokenizer and the weights.
            device: Target device. Defaults to CUDA when available, else CPU.
            tokenizer: Optional pre-loaded tokenizer. When ``None``, it is
                loaded from ``repo_id``.

        Returns:
            A ``CISAModel`` whose core is on ``device`` and in eval mode.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if tokenizer is None:
            print(f"Tokenizer yükleniyor: {repo_id}")
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
        print("Model mimarisi oluşturuluyor...")
        # Builds the architecture (this also loads the base BERT checkpoint for
        # both encoders); the fine-tuned weights below overwrite it.
        core = PositionAwareDualEncoderEBSA(
            model_name='dbmdz/bert-base-turkish-cased',
            num_sentiment_labels=3,
            dropout_rate=0.1,
            use_r_drop=False,
            stochastic_depth_rate=0.1,
        )
        print("Ağırlıklar indiriliyor...")
        weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        state_dict = cls._read_checkpoint(weights_path)
        missing, unexpected = core.load_state_dict(state_dict, strict=False)
        if missing:
            print(f" Eksik key'ler : {missing}")
        if unexpected:
            print(f" Beklenmeyen key: {unexpected}")
        core.to(device).eval()
        print("Model hazır.")
        return cls(core, tokenizer)

    @staticmethod
    def _read_checkpoint(weights_path):
        """Load a ``state_dict`` file onto CPU, refusing arbitrary pickles when possible.

        ``weights_only=True`` restricts unpickling to tensors and plain
        containers, which matters for files downloaded from the Hub. PyTorch
        versions without that argument raise ``TypeError``; they fall back to a
        plain load. Loading onto CPU avoids a second copy on the GPU. The
        model is moved to its device after ``load_state_dict``.
        """
        try:
            return torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            return torch.load(weights_path, map_location="cpu")

    @staticmethod
    def _locate_entity_tokens(offset_map, entity_start, entity_end):
        """Map the character span [entity_start, entity_end) to inclusive token indices.

        A token belongs to the entity when its character span overlaps the
        entity span. Tokens with empty offsets (special tokens, padding) are
        skipped. This is robust to a start offset that points at whitespace
        and to entities cut off by truncation, where only the visible part is
        kept. Returns ``(0, 0)`` when no token overlaps, e.g. when the entity
        lies entirely beyond ``max_length``.
        """
        spans = offset_map.tolist() if hasattr(offset_map, "tolist") else offset_map
        hits = [idx for idx, (cs, ce) in enumerate(spans)
                if ce > cs and cs < entity_end and ce > entity_start]
        if not hits:
            return 0, 0
        return hits[0], hits[-1]

    # ── inference ────────────────────────────────────────────
    @torch.no_grad()
    def predict(self, text: str, entity_text: str,
                entity_start: int, entity_end: int,
                max_length: int = 256) -> dict:
        """Predict sentiment toward, and relation type of, one entity mention.

        Args:
            text: Full text passage.
            entity_text: Person name as it appears in the text.
            entity_start: Character start offset of ``entity_text`` in ``text``.
            entity_end: Character end offset (exclusive) of ``entity_text``.
            max_length: Maximum passage token length (default 256). The
                entity-side input uses ``max_length // 2``.

        Returns:
            {
                "sentiment": int,              # 0=Negative 1=Neutral 2=Positive
                "sentiment_label": str,
                "sentiment_probs": list[float],
                "relation": int,               # 0=Indirect 1=Direct
                "relation_label": str,
                "relation_probs": list[float],
            }
        """
        # Passage tokenization (character offsets are needed to locate the entity).
        txt_tok = self.tokenizer(
            text, padding='max_length', truncation=True,
            max_length=max_length, return_tensors="pt",
            return_token_type_ids=False, return_offsets_mapping=True)
        offset_map = txt_tok.pop("offset_mapping")[0]
        # Entity-side input in the same format as training: entity name, then
        # up to 1800 characters of following context. The literal [CLS]/[SEP]
        # markers are part of that format.
        context = text[entity_end: min(len(text), entity_end + 1800)]
        ent_input = f"[CLS] {entity_text} [SEP] {context} [SEP]"
        ent_tok = self.tokenizer(
            ent_input, padding='max_length', truncation=True,
            max_length=max_length // 2, return_tensors="pt",
            return_token_type_ids=False)
        # Inclusive token span of the entity, and the matching 0/1 token mask.
        s_tok, e_tok = self._locate_entity_tokens(offset_map, entity_start, entity_end)
        pos_mask = [0] * max_length
        for idx in range(s_tok, min(e_tok + 1, max_length)):
            pos_mask[idx] = 1
        inputs = {
            "text_input_ids": txt_tok["input_ids"].to(self.device),
            "text_attention_mask": txt_tok["attention_mask"].to(self.device),
            "entity_input_ids": ent_tok["input_ids"].to(self.device),
            "entity_attention_mask": ent_tok["attention_mask"].to(self.device),
            "entity_positions": [[s_tok, e_tok]],
            "position_mask": [pos_mask],
        }
        out = self.model(**inputs)
        s_probs = F.softmax(out["sentiment_logits"], -1)[0].cpu().tolist()
        r_probs = F.softmax(out["relation_logits"], -1)[0].cpu().tolist()
        s_pred = int(torch.argmax(out["sentiment_logits"], -1).item())
        r_pred = int(torch.argmax(out["relation_logits"], -1).item())
        return {
            "sentiment": s_pred,
            "sentiment_label": self.SENTIMENT_LABELS[s_pred],
            "sentiment_probs": s_probs,
            "relation": r_pred,
            "relation_label": self.RELATION_LABELS[r_pred],
            "relation_probs": r_probs,
        }