---
datasets:
- rungalileo/ragbench
language:
- en
- ru
base_model:
- MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli
---

# DebertaTrace Model

Карточка модели для token classification классификации ответов RAG-модели без оконного прохода по тексту, аналогичному в Luna. На выходе — три логита: релевантность, использование и приверженность (правдивость).

## Пример использования

```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CMCenjoyer/deberta-trace")


class DebertaTrace(nn.Module):
    """DeBERTa backbone with three per-token binary heads (relevance / utilization / adherence)."""

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        hid = base_model.config.hidden_size
        self.rel_head = nn.Linear(hid, 1)   # relevance
        self.util_head = nn.Linear(hid, 1)  # utilization
        self.adh_head = nn.Linear(hid, 1)   # adherence (faithfulness)

    def forward(self, input_ids, attention_mask):
        out = self.base(input_ids=input_ids, attention_mask=attention_mask)
        hs = out.last_hidden_state  # (batch, seq, hidden)
        return {
            'logits_relevance': self.rel_head(hs).squeeze(-1),
            'logits_utilization': self.util_head(hs).squeeze(-1),
            'logits_adherence': self.adh_head(hs).squeeze(-1),
        }


base_model = AutoModel.from_pretrained("CMCenjoyer/deberta-trace")
model = DebertaTrace(base_model)

# Download heads_weights.pt into the local HF cache and load the head weights.
file_path = hf_hub_download(repo_id="CMCenjoyer/deberta-trace", filename="heads_weights.pt")
heads_weights = torch.load(file_path, weights_only=True)
model.rel_head.load_state_dict(heads_weights['rel_head'])
model.util_head.load_state_dict(heads_weights['util_head'])
model.adh_head.load_state_dict(heads_weights['adh_head'])


def preprocess(example, max_length=512):
    '''
    Preprocess one input example into input_ids + attention_mask plus
    boolean masks marking the context (document) tokens and the response
    tokens inside the concatenated sequence.
    '''
    question_ids = tokenizer.encode(example["question"], add_special_tokens=False)
    doc_ids = []
    for doc in example["documents_sentences"]:
        for _, sent in doc:
            doc_ids += tokenizer.encode(sent, add_special_tokens=False)
    response_ids = tokenizer.encode(example["response"], add_special_tokens=False)

    sep_id = tokenizer.sep_token_id
    # Layout: question [SEP] documents [SEP] response
    input_ids = question_ids + [sep_id] + doc_ids + [sep_id] + response_ids
    context_mask = [0] * (len(question_ids) + 1) + [1] * len(doc_ids) + [0] + [0] * len(response_ids)
    response_mask = [0] * (len(question_ids) + len(doc_ids) + 2) + [1] * len(response_ids)

    # Truncate all three sequences consistently to max_length.
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        context_mask = context_mask[:max_length]
        response_mask = response_mask[:max_length]

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor([1] * len(input_ids), dtype=torch.long),
        "context_mask": torch.tensor(context_mask, dtype=torch.bool),
        "response_mask": torch.tensor(response_mask, dtype=torch.bool),
    }


def compute_trace_metrics_inference(logits, masks, threshold=0.5):
    '''
    Compute per-example TRACE metrics for a batch (all sequences in the
    batch must be padded to one fixed length).
    '''
    rel_pred = torch.sigmoid(logits['logits_relevance'].detach().cpu()) > threshold
    util_pred = torch.sigmoid(logits['logits_utilization'].detach().cpu()) > threshold
    adh_pred = torch.sigmoid(logits['logits_adherence'].detach().cpu()) > threshold
    ctx_m = masks['context_mask'].detach().cpu()
    resp_m = masks['response_mask'].detach().cpu()

    def rate(pred, mask):
        # sum(pred & mask) / sum(mask), guarded against empty masks
        num = (pred & mask).sum(dim=1).float()
        den = mask.sum(dim=1).float().clamp(min=1)
        return num.div(den)

    relevance_rate = rate(rel_pred, ctx_m)
    utilization_rate = rate(util_pred, ctx_m)
    adherence_rate = rate(adh_pred, resp_m)

    # completeness: of the context tokens predicted relevant, how many are also utilized.
    # The denominator is restricted to context tokens so it matches the numerator's mask.
    num_ru = (rel_pred & util_pred & ctx_m).sum(dim=1).float()
    den_r = (rel_pred & ctx_m).sum(dim=1).float().clamp(min=1)
    completeness = num_ru.div(den_r)

    return {
        'relevance_rate': relevance_rate,
        'utilization_rate': utilization_rate,
        'adherence_rate': adherence_rate,
        'completeness': completeness,
    }


from datasets import load_dataset

ds = load_dataset("rungalileo/ragbench", "delucionqa")
ex = preprocess(ds['train'][9])

model.eval()
with torch.no_grad():
    outputs = model(ex["input_ids"].unsqueeze(0), ex["attention_mask"].unsqueeze(0))
batch_metrics = compute_trace_metrics_inference(
    outputs,
    {'context_mask': ex["context_mask"].unsqueeze(0),
     'response_mask': ex["response_mask"].unsqueeze(0)},
)
batch_metrics
```