ChatBot_Yte / src /rag /predict.py
giangpvg's picture
Update Data
9af9229
Raw
History Blame Contribute Delete
4.02 kB
import torch
from transformers import AutoTokenizer
from model_intent import JointPhoBERTModel
from preprocess import LabelEncoder
import json
import warnings
from underthesea import word_tokenize
warnings.filterwarnings('ignore')
def load_encoder(train_path):
    """Return a LabelEncoder fitted on the JSON training data at train_path.

    The file is read as UTF-8 JSON and the parsed object is handed to
    LabelEncoder.fit() unchanged.
    """
    with open(train_path, mode='r', encoding='utf-8') as train_file:
        samples = json.loads(train_file.read())

    label_encoder = LabelEncoder()
    label_encoder.fit(samples)
    return label_encoder
def predict(sentence, model, tokenizer, encoder, device, max_len=128):
    """Predict the intent and a per-word NER tag for one sentence.

    The sentence is word-segmented with underthesea (multi-syllable words are
    joined with '_'), every word is encoded into subword ids, and the tag
    predicted at a word's first subword is used as the tag for the whole word.

    Args:
        sentence: Raw input text.
        model: Object with eval() that is called as
            model(input_ids, attention_mask) and returns
            (intent_logits, ner_logits). Assumed shapes are
            (1, num_intents) and (1, max_len, num_tags) -- TODO confirm
            against the model definition.
        tokenizer: Tokenizer exposing encode() plus cls_token_id,
            sep_token_id and pad_token_id.
        encoder: Object exposing the id2intent and id2ner dictionaries.
        device: Torch device the input tensors are moved to.
        max_len: Fixed sequence length; longer inputs are truncated and
            shorter ones are padded up to it.

    Returns:
        A tuple (intent_label, entities). entities holds exactly one
        (word, tag) pair per segmented word, in order. Words that produced
        no subwords or that were cut off by truncation are tagged "O".
        ("UNKNOWN", []) is returned when segmentation yields no words.
    """
    model.eval()

    # Word-segment with underthesea; compound words come back joined by '_'.
    sentence = word_tokenize(sentence, format="text")
    words = sentence.split()
    if not words:
        return "UNKNOWN", []

    input_ids = [tokenizer.cls_token_id]
    # Kept strictly parallel to `words`: index of each word's first subword
    # in input_ids, or None when the tokenizer produced nothing for the word.
    # Skipping such words here would shift every later word onto the wrong
    # prediction slot.
    first_subword_positions = []
    for word in words:
        word_tokens = tokenizer.encode(word, add_special_tokens=False)
        if word_tokens:
            first_subword_positions.append(len(input_ids))
            input_ids.extend(word_tokens)
        else:
            first_subword_positions.append(None)
    input_ids.append(tokenizer.sep_token_id)

    if len(input_ids) > max_len:
        input_ids = input_ids[:max_len]
        input_ids[-1] = tokenizer.sep_token_id

    # Index of the closing SEP token. After truncation this slot may have
    # replaced a real subword, so only positions strictly before it still
    # belong to words.
    sep_position = len(input_ids) - 1

    attention_mask = [1] * len(input_ids)

    # Pad ids and mask up to the fixed length.
    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids.extend([tokenizer.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)

    input_ids_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
    attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long).to(device)

    with torch.no_grad():
        intent_logits, ner_logits = model(input_ids_tensor, attention_mask_tensor)

    intent_id = torch.argmax(intent_logits, dim=1).item()
    intent_label = encoder.id2intent.get(intent_id, "UNKNOWN")

    # Plain Python ints so the id2ner dictionary lookup uses ordinary keys.
    ner_preds = torch.argmax(ner_logits, dim=2).squeeze(0).cpu().tolist()

    # Read the NER tag of each original word from its first subword.
    extracted_entities = []
    for word, position in zip(words, first_subword_positions):
        if position is None or position >= sep_position:
            # No subword survived for this word, so there is no prediction.
            extracted_entities.append((word, "O"))
        else:
            tag_label = encoder.id2ner.get(ner_preds[position], "O")
            extracted_entities.append((word, tag_label))

    return intent_label, extracted_entities
if __name__ == "__main__":
    # Interactive demo: load the label encoder, tokenizer and trained weights,
    # then read sentences from stdin and print the predicted intent and the
    # non-"O" entities until the user quits.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_name = "vinai/phobert-base-v2"

    print("Loading Encoder and Tokenizer...")
    encoder = load_encoder('../data/train.json')
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("Loading Model Weights...")
    model = JointPhoBERTModel(model_name, encoder.get_num_intents(), encoder.get_num_ner_tags())
    # NOTE(review): torch.load unpickles the checkpoint file, so only load
    # checkpoints from a trusted source (consider weights_only=True where the
    # installed torch version supports it).
    model.load_state_dict(torch.load("checkpoints/best_joint_model.pth", map_location=device))
    model.to(device)

    print("\n" + "="*50)
    print("MÔ HÌNH ĐÃ SẴN SÀNG!")
    print("="*50 + "\n")

    while True:
        try:
            text = input("Nhập câu hỏi (hoặc 'q' để thoát): ")
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (e.g. piped input ran out) or the user pressed
            # Ctrl-C: leave the loop quietly instead of dumping a traceback.
            print()
            break

        stripped = text.strip()
        if stripped.lower() == 'q':
            break
        if not stripped:
            continue

        intent, entities = predict(text, model, tokenizer, encoder, device)

        print("\n--- KẾT QUẢ DỰ ĐOÁN ---")
        print(f"🔹 Intent: {intent}")
        print("🔹 Thực thể (NER):")

        # Only words carrying a real entity tag are reported.
        tagged_words = [(word, tag) for word, tag in entities if tag != "O"]
        for word, tag in tagged_words:
            print(f" [{tag}] -> {word.replace('_', ' ')}")
        if not tagged_words:
            print(" (Không tìm thấy thực thể nào)")
        print("-" * 30 + "\n")