ManB2207540's picture
generate demo
c750faa
# finetuned model
# Import các thư viện cần thiết
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
import torch
import os
import pandas as pd
import time
from datasets import Dataset
# Hàm tải dữ liệu Parquet với xử lý lỗi và thử lại
def load_squad_parquet(split='train', max_retries=3, delay=5):
splits = {'train': 'plain_text/train-00000-of-00001.parquet'}
path = "hf://datasets/rajpurkar/squad/" + splits[split]
for attempt in range(max_retries):
try:
df = pd.read_parquet(path)
print(f"Tải tập dữ liệu SQuAD {split} thành công sau {attempt + 1} lần thử!")
return df
except Exception as e:
print(f"Lần thử {attempt + 1}/{max_retries} thất bại: {e}")
if attempt < max_retries - 1:
print(f"Đợi {delay} giây trước khi thử lại...")
time.sleep(delay)
else:
print("Đã hết số lần thử. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
return None
# Tải tập dữ liệu SQuAD (chỉ tải train để kiểm tra)
train_df = load_squad_parquet('train')
if train_df is None:
raise ValueError("Không thể tải tập dữ liệu SQuAD. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
# Chuyển đổi DataFrame thành Dataset để tương thích với pipeline
train_ds = Dataset.from_pandas(train_df)
# Đường dẫn đến thư mục chứa mô hình và tokenizer đã tinh chỉnh
model_dir = "/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned"
# Kiểm tra xem thư mục tồn tại
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Thư mục {model_dir} không tồn tại. Vui lòng kiểm tra lại đường dẫn.")
# Danh sách file cần thiết cho mô hình và tokenizer
required_model_files = ['config.json', 'model.safetensors'] # Chỉ cần model.safetensors vì đã sử dụng định dạng này
required_tokenizer_files = ['prophetnet.tokenizer', 'tokenizer_config.json'] # File tokenizer cần thiết
all_files = os.listdir(model_dir)
missing_model_files = [f for f in required_model_files if f not in all_files]
missing_tokenizer_files = [f for f in required_tokenizer_files if f not in all_files]
if missing_model_files or missing_tokenizer_files:
print(f"Thiếu file trong {model_dir}:")
if missing_model_files:
print(f" - File mô hình thiếu: {missing_model_files}")
if missing_tokenizer_files:
print(f" - File tokenizer thiếu: {missing_tokenizer_files}")
raise FileNotFoundError("Vui lòng cung cấp đầy đủ file mô hình và tokenizer.")
# Khởi tạo tokenizer và mô hình từ thư mục đã tinh chỉnh
try:
# Chỉ định rõ ràng rằng sử dụng định dạng safetensors
tokenizer = ProphetNetTokenizer.from_pretrained(model_dir)
model = ProphetNetForConditionalGeneration.from_pretrained(model_dir)
print("Tải mô hình và tokenizer từ thư mục đã tinh chỉnh thành công!")
except Exception as e:
raise RuntimeError(f"Lỗi khi tải mô hình/tokenizer: {e}. Vui lòng kiểm tra cấu trúc thư mục hoặc cập nhật thư viện transformers.")
# Tạo pipeline để tạo câu hỏi (question generation)
pipe = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256, # Giới hạn độ dài tối đa của câu hỏi
num_return_sequences=1, # Tạo 1 câu hỏi duy nhất
device=0 if torch.cuda.is_available() else -1 # Sử dụng GPU nếu có, mặc định CPU
)
# Hàm tạo câu hỏi từ context và answer
def generate_question(context, answer):
# Định dạng input theo cách mô hình đã được tinh chỉnh
input_text = f"context: {context} answer: {answer}"
try:
result = pipe(input_text)[0]['generated_text']
return result
except Exception as e:
print(f"Lỗi khi tạo câu hỏi: {e}")
return None
# Thử nghiệm pipeline với một ví dụ
context = "The Vatican Apostolic Library is located in Vatican City."
answer = "Vatican City"
question = generate_question(context, answer)
if question:
print(f"Context: {context}")
print(f"Answer: {answer}")
print(f"Generated Question: {question}")
# (Tùy chọn) Kiểm tra với dữ liệu từ SQuAD
sample = train_ds[0] # Lấy mẫu đầu tiên từ tập dữ liệu
context_sample = sample['context']
answer_sample = sample['answers']['text'][0] if sample['answers']['text'] else "No answer"
question_sample = generate_question(context_sample, answer_sample)
if question_sample:
print(f"\nSample Context: {context_sample}")
print(f"Sample Answer: {answer_sample}")
print(f"Generated Question: {question_sample}")
# /Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/test.py