from typing import Dict, List, Union

from transformers import BertPreTrainedModel, BertModel, PreTrainedTokenizer
import torch.nn as nn
import torch


class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT encoder followed by a single linear classification head.

    The hidden state of the first token of each sequence is fed to a linear
    layer that produces ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Optional attention mask passed to the encoder.
            labels: Optional class indices. When given, the cross-entropy
                loss is returned instead of the logits.
            **kwargs: Accepted and ignored, so that a full tokenizer output
                can be splatted into the call. NOTE(review): this means extra
                encoder inputs such as ``token_type_ids`` are not forwarded
                to the encoder.

        Returns:
            The cross-entropy loss tensor when ``labels`` is not ``None``;
            otherwise the logits of shape ``[batch_size, num_labels]``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Slice out the first token of every sequence:
        # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)  # [batch_size, num_labels]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """Classify one or more input texts.

        Args:
            texts: A single text or a list of texts.
            tokenizer: Tokenizer instance compatible with this model.
            batch_size: Number of texts tokenized and scored per forward
                pass. Must be at least 1.
            return_probabilities: When ``True``, each result also carries a
                ``"score"`` entry with the softmax probability of the
                predicted class.
            device: Device the tokenized inputs are moved to. The model
                itself is not moved, so this must match the device holding
                the model weights. Defaults to ``None``, meaning the model's
                current device (``self.device``).

        Returns:
            One dict per input text, in input order:
            ``{"text": ..., "label": ...}``, plus ``"score"`` when
            ``return_probabilities`` is ``True``. An empty input list yields
            an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        # Follow the model's own device unless the caller overrides it.
        if device is None:
            device = self.device

        # Normalize a single string to a one-element list.
        if isinstance(texts, str):
            texts = [texts]

        predictions: List[Dict] = []

        # Inference must run with dropout disabled; remember the previous
        # mode so a model that is mid-training is handed back unchanged.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start:start + batch_size]

                    # Tokenize the batch and move the tensors to the target
                    # device.
                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=512,  # BERT's usual maximum sequence length
                    ).to(device)

                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    # Convert once per batch instead of calling .item() for
                    # every element.
                    for text, class_id, score in zip(
                        batch_texts, class_ids.tolist(), scores.tolist()
                    ):
                        result = {
                            "text": text,
                            "label": self.config.id2label[class_id],
                        }
                        if return_probabilities:
                            result["score"] = score
                        predictions.append(result)
        finally:
            self.train(was_training)

        return predictions