"""Fine-tunes google/flan-t5-small on two instruction-prefixed tasks built from
local files: headline generation ("headline: ...") and summarization
("summarize: ..."), then saves the resulting model and tokenizer."""

import os
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

BASE_DIR = Path(__file__).resolve().parent.parent
DATA_ROOT = BASE_DIR / "content"
TITLE_ROOT = BASE_DIR / "titles"
SUMMARY_ROOT = BASE_DIR / "summary"

MODEL_ID = "google/flan-t5-small"
OUTPUT_MODEL_DIR = BASE_DIR / "summarizer" / "models" / "flan_t5_custom"

MAX_INPUT_LEN = 512
MAX_TARGET_LEN = 128

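# Assumed data layout: each source text is a .txt file under content/, and its
# reference title and/or summary (when available) sits at the same relative path
# under titles/ and summary/ (e.g. content/a/x.txt pairs with titles/a/x.txt).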
def load_data():
    """Loads the data and builds pairs: instruction + text -> target."""
    dataset_dict = {"input_text": [], "target_text": []}

    print(f"📂 Looking for data in: {DATA_ROOT}")

    files = list(DATA_ROOT.rglob("*.txt"))
    for txt_file in files:
        rel_path = txt_file.relative_to(DATA_ROOT)

        with open(txt_file, "r", encoding="utf-8") as f:
            ocr_content = f.read().strip()

        if not ocr_content:
            continue

        # Headline pair: instruction-prefixed OCR text -> title.
        t_file = TITLE_ROOT / rel_path
        if t_file.exists():
            with open(t_file, "r", encoding="utf-8") as f:
                dataset_dict["input_text"].append(f"headline: {ocr_content}")
                dataset_dict["target_text"].append(f.read().strip())

        # Summary pair: instruction-prefixed OCR text -> summary.
        s_file = SUMMARY_ROOT / rel_path
        if s_file.exists():
            with open(s_file, "r", encoding="utf-8") as f:
                dataset_dict["input_text"].append(f"summarize: {ocr_content}")
                dataset_dict["target_text"].append(f.read().strip())

    return Dataset.from_dict(dataset_dict)


def main():
    raw_dataset = load_data()
    if len(raw_dataset) == 0:
        print("❌ No files found in content/titles/summary. Check the paths.")
        return

    dataset = raw_dataset.train_test_split(test_size=0.1)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    def preprocess(examples):
        # Tokenize instruction-prefixed inputs and reference targets to fixed lengths.
        model_inputs = tokenizer(examples["input_text"], max_length=MAX_INPUT_LEN, truncation=True, padding="max_length")
        labels = tokenizer(text_target=examples["target_text"], max_length=MAX_TARGET_LEN, truncation=True, padding="max_length")
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(preprocess, batched=True)

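    # NOTE: padding="max_length" in preprocess() already pads every example to a fixed
    # length, so the DataCollatorForSeq2Seq below has little dynamic padding left to do,
    # and pad tokens already present in the labels are not replaced with -100, so they
    # contribute to the loss. Dropping the padding argument and letting the collator pad
    # per batch is a common alternative.
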
    training_args = Seq2SeqTrainingArguments(
        output_dir="./tmp_results",
        eval_strategy="epoch",
        learning_rate=3e-4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=15,
        predict_with_generate=True,
        fp16=False,  # T5-family checkpoints are commonly reported to be unstable in float16
        logging_steps=10,
        generation_max_length=MAX_TARGET_LEN,
        generation_num_beams=4,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )

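    # Intermediate checkpoints go to ./tmp_results (at most save_total_limit are kept);
    # the final model and tokenizer are exported to OUTPUT_MODEL_DIR after training.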
print(f"🚀 Rozpoczynam uczenie na {len(raw_dataset)} przykładach...") |
|
|
trainer.train() |
|
|
|
|
|
|
|
|
    os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_MODEL_DIR)
    tokenizer.save_pretrained(OUTPUT_MODEL_DIR)
    print(f"✨ Model trained and saved to: {OUTPUT_MODEL_DIR}")


if __name__ == "__main__":
    main()
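
# Sketch of inference with the fine-tuned checkpoint (assumes training has finished
# and OUTPUT_MODEL_DIR exists; the input text and prefix below are illustrative):
#
#   tokenizer = AutoTokenizer.from_pretrained(OUTPUT_MODEL_DIR)
#   model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_MODEL_DIR)
#   inputs = tokenizer("summarize: <document text>", return_tensors="pt",
#                      truncation=True, max_length=MAX_INPUT_LEN)
#   output_ids = model.generate(**inputs, max_length=MAX_TARGET_LEN, num_beams=4)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))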