import os

import torch

from src.model import CRF_Tagger
from src.preprocessing import process_demo_sentence


def predict(model, loader, count_loss=True):
    """Run ``model`` over every batch in ``loader`` and collect flat tag ids.

    Args:
        model: Tagger whose call ``model(x, y, mask)`` returns a scalar loss
            tensor and whose ``decode(x, mask)`` yields one sequence of
            predicted tag ids per sentence in the batch.
        loader: Iterable of ``(x, y, _)`` batches. Positions in ``y`` equal to
            ``-1`` are treated as padding and excluded.
        count_loss: If True, also accumulate the per-batch loss.

    Returns:
        ``(all_preds, all_true, avg_loss)`` where ``all_preds`` / ``all_true``
        are predicted and gold tag ids flattened over every non-padded token,
        and ``avg_loss`` is the summed loss divided by the number of batches
        iterated (``0.0`` if ``count_loss`` is False or ``loader`` is empty).
    """
    model.eval()  # inference behaviour for layers such as dropout / batch-norm
    all_preds, all_true = [], []
    loss = 0.0
    num_batches = 0
    with torch.no_grad():  # no gradient tracking needed for evaluation
        for x, y, _ in loader:
            num_batches += 1
            mask = y != -1  # True on real tokens, False on padding

            if count_loss:
                loss += model(x, y, mask).item()

            preds = model.decode(x, mask)
            # One entry per sentence in the mini-batch.
            for pred_seq, true_seq, m in zip(preds, y, mask):
                # Boolean-mask indexing keeps only the non-padded gold labels.
                true_labels = true_seq[m].tolist()
                all_preds.extend(pred_seq)
                all_true.extend(true_labels)

    # Count batches actually seen instead of len(loader): avoids a
    # ZeroDivisionError on an empty loader and works for loaders without len().
    avg_loss = loss / num_batches if num_batches else 0.0
    return all_preds, all_true, avg_loss


def predict_demo(text, model_path=None):
    """Tag a single raw sentence with the saved CRF model.

    Args:
        text: Raw sentence, preprocessed by ``process_demo_sentence``.
        model_path: Optional path to a saved ``state_dict``. Defaults to
            ``<project root>/models/best_epoch_16.pt``, where the project
            root is the parent of this file's directory.

    Returns:
        ``(tokens, labels)``: the tokens returned by ``process_demo_sentence``
        and one BIO label string per predicted tag.
    """
    if model_path is None:
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        model_path = os.path.join(base_dir, "models", "best_epoch_16.pt")

    # Tag id -> BIO label; must match the label ids used at training time.
    id_tag = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC'}

    x, tokens = process_demo_sentence(text)  # original note: 1 x seq_length x 768

    model = CRF_Tagger(input_dim=x.size(2), num_tags=len(id_tag))
    # map_location="cpu" lets a checkpoint saved during a GPU run load on a
    # CPU-only machine; load_state_dict copies into the CPU-built model anyway.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    with torch.no_grad():
        preds = model.decode(x)

    # decode returns one sequence per batch item; this batch holds one sentence.
    labels = [id_tag[lab] for lab in preds[0]]
    return tokens, labels