|
|
---
datasets:
- rungalileo/ragbench
language:
- en
- ru
base_model:
- MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli
---
|
|
|
|
|
|
|
|
# DebertaTrace Model
|
|
|
|
|
Карточка модели для потокенной классификации (token classification) ответов RAG-системы без прохода по тексту скользящим окном (как это сделано в Luna). Для каждого токена модель выдаёт три логита: релевантность, использование и приверженность (правдивость).
|
|
## Пример использования
|
|
|
|
|
```python
|
|
import torch |
|
|
from transformers import AutoModel |
|
|
from torch import nn |
|
|
from huggingface_hub import hf_hub_download |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
# Tokenizer published alongside the fine-tuned checkpoint; `preprocess` relies on it.
tokenizer = AutoTokenizer.from_pretrained(
    "CMCenjoyer/deberta-trace",
)
|
|
|
|
|
|
|
|
class DebertaTrace(nn.Module):
    """Encoder with three per-token binary heads for TRACE-style RAG scoring.

    ``base_model`` must expose ``config.hidden_size`` and, when called with
    ``input_ids`` / ``attention_mask``, return an object carrying
    ``last_hidden_state`` (as a Hugging Face ``AutoModel`` does). Each head is
    a ``Linear(hidden_size, 1)`` applied to every token position, so every
    returned logit tensor has the hidden state's shape without its trailing
    feature dimension.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        hidden_size = base_model.config.hidden_size
        # The attribute names below are part of the checkpoint format: the
        # head weights are loaded into them by name, so do not rename them.
        self.rel_head = nn.Linear(hidden_size, 1)
        self.util_head = nn.Linear(hidden_size, 1)
        self.adh_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits for relevance, utilization and adherence."""
        hidden = self.base(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        named_heads = (
            ("logits_relevance", self.rel_head),
            ("logits_utilization", self.util_head),
            ("logits_adherence", self.adh_head),
        )
        return {key: head(hidden).squeeze(-1) for key, head in named_heads}
|
|
|
|
|
# Backbone encoder plus the three freshly initialised heads.
base_model = AutoModel.from_pretrained("CMCenjoyer/deberta-trace")
model = DebertaTrace(base_model)

# Download heads_weights.pt (the trained head parameters) into the local HF cache.
file_path = hf_hub_download(repo_id="CMCenjoyer/deberta-trace", filename="heads_weights.pt")

# map_location="cpu" lets the checkpoint load even on machines that lack the
# device it was saved from; load_state_dict then copies the values into the
# heads' existing parameters, wherever those live.
heads_weights = torch.load(file_path, map_location="cpu", weights_only=True)

# The checkpoint is a dict keyed by the head attribute names of DebertaTrace.
for head_name in ("rel_head", "util_head", "adh_head"):
    getattr(model, head_name).load_state_dict(heads_weights[head_name])
|
|
def preprocess(example, max_length=512):
    """Tokenize one example into model inputs plus context/response token masks.

    The sequence is laid out as ``question [SEP] documents [SEP] response``
    with no other special tokens. ``example`` is read through the keys
    ``"question"``, ``"documents_sentences"`` (an iterable of documents, each an
    iterable of two-element items whose second element is the sentence text)
    and ``"response"``. Tokenization uses the module-level ``tokenizer``.

    Args:
        example: mapping with the keys listed above.
        max_length: sequences longer than this are cut from the END, so the
            response tokens are the first to be dropped.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (1-D long tensors; the
        mask is all ones), plus boolean ``context_mask`` (True on document
        tokens) and ``response_mask`` (True on response tokens).
    """
    def encode(text):
        return tokenizer.encode(text, add_special_tokens=False)

    question_tokens = encode(example["question"])
    context_tokens = [
        token_id
        for document in example["documents_sentences"]
        for _, sentence in document
        for token_id in encode(sentence)
    ]
    answer_tokens = encode(example["response"])
    separator = tokenizer.sep_token_id

    sequence = [*question_tokens, separator, *context_tokens, separator, *answer_tokens]

    n_prefix = len(question_tokens) + 1  # question tokens + first [SEP]
    n_context = len(context_tokens)
    n_answer = len(answer_tokens)
    in_context = [False] * n_prefix + [True] * n_context + [False] * (n_answer + 1)
    in_answer = [False] * (n_prefix + n_context + 1) + [True] * n_answer

    if len(sequence) > max_length:
        sequence, in_context, in_answer = (
            part[:max_length] for part in (sequence, in_context, in_answer)
        )

    ids_tensor = torch.tensor(sequence, dtype=torch.long)
    return {
        "input_ids": ids_tensor,
        "attention_mask": torch.ones_like(ids_tensor),
        "context_mask": torch.tensor(in_context, dtype=torch.bool),
        "response_mask": torch.tensor(in_answer, dtype=torch.bool),
    }
|
|
def compute_trace_metrics_inference(logits, masks, threshold=0.5):
    """Compute TRACE metrics for every row of a batch.

    All tensors are reduced over dim 1, so every row of a batch must share one
    (padded) length; padded positions must be 0/False in both masks.

    Args:
        logits: dict with ``logits_relevance``, ``logits_utilization`` and
            ``logits_adherence`` tensors of per-token logits, e.g. the output
            of ``DebertaTrace.forward`` (presumably shaped (batch, seq)).
        masks: dict with ``context_mask`` and ``response_mask`` tensors of the
            same shape; non-zero marks document / response tokens.
        threshold: probability above which a token is predicted positive.

    Returns:
        dict of 1-D float tensors, one value per batch row:
        ``relevance_rate`` / ``utilization_rate`` - share of context tokens
        predicted relevant / utilized; ``adherence_rate`` - share of response
        tokens predicted adherent; ``completeness`` - among context tokens
        predicted relevant, the share also predicted utilized. Empty
        denominators are clamped to 1, which yields 0.
    """
    def predict(key):
        return torch.sigmoid(logits[key].detach().cpu()) > threshold

    rel_pred = predict('logits_relevance')
    util_pred = predict('logits_utilization')
    adh_pred = predict('logits_adherence')

    # .bool() so integer 0/1 masks behave exactly like boolean masks under `&`.
    ctx_m = masks['context_mask'].detach().cpu().bool()
    resp_m = masks['response_mask'].detach().cpu().bool()

    def rate(pred, mask):
        # sum(pred & mask) / sum(mask); the clamp avoids division by zero.
        num = (pred & mask).sum(dim=1).float()
        den = mask.sum(dim=1).float().clamp(min=1)
        return num.div(den)

    relevance_rate = rate(rel_pred, ctx_m)
    utilization_rate = rate(util_pred, ctx_m)
    adherence_rate = rate(adh_pred, resp_m)

    # Completeness: of the context tokens predicted relevant, how many are also
    # predicted utilized. The denominator must be restricted to the context as
    # well; counting relevance predictions on question/separator/response/padding
    # tokens would deflate the score.
    completeness = rate(util_pred, rel_pred & ctx_m)

    return {
        'relevance_rate': relevance_rate,
        'utilization_rate': utilization_rate,
        'adherence_rate': adherence_rate,
        'completeness': completeness,
    }
|
|
from datasets import load_dataset

# Score a single example from the "delucionqa" subset of RAGBench.
ds = load_dataset("rungalileo/ragbench", "delucionqa")
ex = preprocess(ds['train'][9])

model.eval()
with torch.no_grad():
    # preprocess returns unbatched 1-D tensors; unsqueeze(0) adds the batch dim.
    outputs = model(ex["input_ids"].unsqueeze(0), ex["attention_mask"].unsqueeze(0))

batch_metrics = compute_trace_metrics_inference(
    outputs,
    {
        'context_mask': ex["context_mask"].unsqueeze(0),
        'response_mask': ex["response_mask"].unsqueeze(0),
    },
)

# A bare `batch_metrics` expression only displays in a notebook/REPL;
# printing works when the snippet is run as a script too.
print({name: value.tolist() for name, value in batch_metrics.items()})