|
|
|
|
|
|
|
|
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline |
|
|
import torch |
|
|
import os |
|
|
import pandas as pd |
|
|
import time |
|
|
from datasets import Dataset |
|
|
|
|
|
|
|
|
def load_squad_parquet(split='train', max_retries=3, delay=5): |
|
|
splits = {'train': 'plain_text/train-00000-of-00001.parquet'} |
|
|
path = "hf://datasets/rajpurkar/squad/" + splits[split] |
|
|
for attempt in range(max_retries): |
|
|
try: |
|
|
df = pd.read_parquet(path) |
|
|
print(f"Tải tập dữ liệu SQuAD {split} thành công sau {attempt + 1} lần thử!") |
|
|
return df |
|
|
except Exception as e: |
|
|
print(f"Lần thử {attempt + 1}/{max_retries} thất bại: {e}") |
|
|
if attempt < max_retries - 1: |
|
|
print(f"Đợi {delay} giây trước khi thử lại...") |
|
|
time.sleep(delay) |
|
|
else: |
|
|
print("Đã hết số lần thử. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.") |
|
|
return None |
|
|
|
|
|
|
|
|
train_df = load_squad_parquet('train') |
|
|
if train_df is None: |
|
|
raise ValueError("Không thể tải tập dữ liệu SQuAD. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.") |
|
|
|
|
|
|
|
|
train_ds = Dataset.from_pandas(train_df) |
|
|
|
|
|
|
|
|
model_dir = "/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned" |
|
|
|
|
|
|
|
|
if not os.path.exists(model_dir): |
|
|
raise FileNotFoundError(f"Thư mục {model_dir} không tồn tại. Vui lòng kiểm tra lại đường dẫn.") |
|
|
|
|
|
|
|
|
required_model_files = ['config.json', 'model.safetensors'] |
|
|
required_tokenizer_files = ['prophetnet.tokenizer', 'tokenizer_config.json'] |
|
|
all_files = os.listdir(model_dir) |
|
|
missing_model_files = [f for f in required_model_files if f not in all_files] |
|
|
missing_tokenizer_files = [f for f in required_tokenizer_files if f not in all_files] |
|
|
|
|
|
if missing_model_files or missing_tokenizer_files: |
|
|
print(f"Thiếu file trong {model_dir}:") |
|
|
if missing_model_files: |
|
|
print(f" - File mô hình thiếu: {missing_model_files}") |
|
|
if missing_tokenizer_files: |
|
|
print(f" - File tokenizer thiếu: {missing_tokenizer_files}") |
|
|
raise FileNotFoundError("Vui lòng cung cấp đầy đủ file mô hình và tokenizer.") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
tokenizer = ProphetNetTokenizer.from_pretrained(model_dir) |
|
|
model = ProphetNetForConditionalGeneration.from_pretrained(model_dir) |
|
|
print("Tải mô hình và tokenizer từ thư mục đã tinh chỉnh thành công!") |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Lỗi khi tải mô hình/tokenizer: {e}. Vui lòng kiểm tra cấu trúc thư mục hoặc cập nhật thư viện transformers.") |
|
|
|
|
|
|
|
|
pipe = pipeline( |
|
|
"text2text-generation", |
|
|
model=model, |
|
|
tokenizer=tokenizer, |
|
|
max_length=256, |
|
|
num_return_sequences=1, |
|
|
device=0 if torch.cuda.is_available() else -1 |
|
|
) |
|
|
|
|
|
|
|
|
def generate_question(context, answer): |
|
|
|
|
|
input_text = f"context: {context} answer: {answer}" |
|
|
try: |
|
|
result = pipe(input_text)[0]['generated_text'] |
|
|
return result |
|
|
except Exception as e: |
|
|
print(f"Lỗi khi tạo câu hỏi: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
context = "The Vatican Apostolic Library is located in Vatican City." |
|
|
answer = "Vatican City" |
|
|
question = generate_question(context, answer) |
|
|
if question: |
|
|
print(f"Context: {context}") |
|
|
print(f"Answer: {answer}") |
|
|
print(f"Generated Question: {question}") |
|
|
|
|
|
|
|
|
sample = train_ds[0] |
|
|
context_sample = sample['context'] |
|
|
answer_sample = sample['answers']['text'][0] if sample['answers']['text'] else "No answer" |
|
|
question_sample = generate_question(context_sample, answer_sample) |
|
|
if question_sample: |
|
|
print(f"\nSample Context: {context_sample}") |
|
|
print(f"Sample Answer: {answer_sample}") |
|
|
print(f"Generated Question: {question_sample}") |
|
|
|
|
|
|