File size: 4,141 Bytes
4721322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import torch
from pathlib import Path
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    DataCollatorForSeq2Seq, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
)

# --- PATH CONFIGURATION ---
# Go one level up from the 'summarizer' folder to the main project folder
BASE_DIR = Path(__file__).resolve().parent.parent
DATA_ROOT = BASE_DIR / "content"     # raw OCR input texts (*.txt)
TITLE_ROOT = BASE_DIR / "titles"     # gold titles, mirroring DATA_ROOT's relative layout
SUMMARY_ROOT = BASE_DIR / "summary"  # gold summaries, mirroring DATA_ROOT's relative layout

MODEL_ID = "google/flan-t5-small"
OUTPUT_MODEL_DIR = BASE_DIR / "summarizer" / "models" / "flan_t5_custom"

MAX_INPUT_LEN = 512    # token cap for the tokenized source text
MAX_TARGET_LEN = 128   # token cap for the tokenized title/summary target

def load_data():
    """Load the data and build pairs: instruction + text -> target.

    Walks DATA_ROOT recursively for ``*.txt`` files. For each file, the raw
    OCR text is paired with the matching title (under TITLE_ROOT) and/or
    summary (under SUMMARY_ROOT) found at the same relative path. Each match
    yields one training example whose input carries a task prefix
    (``headline: ...`` or ``summarize: ...``), so one source file can
    contribute up to two examples.

    Returns:
        datasets.Dataset with "input_text" and "target_text" columns
        (possibly empty if nothing was found).
    """
    dataset_dict = {"input_text": [], "target_text": []}

    print(f"📂 Szukam danych w: {DATA_ROOT}")

    # Both target roots are expected to mirror DATA_ROOT's directory layout.
    tasks = (("headline", TITLE_ROOT), ("summarize", SUMMARY_ROOT))

    # Search the folders recursively.
    for txt_file in DATA_ROOT.rglob("*.txt"):
        rel_path = txt_file.relative_to(DATA_ROOT)

        # 1. Raw OCR text is the input feature; skip empty files.
        ocr_content = txt_file.read_text(encoding="utf-8").strip()
        if not ocr_content:
            continue

        # 2. Add one example per task whose target file exists.
        for prefix, root in tasks:
            target_file = root / rel_path
            if target_file.exists():
                dataset_dict["input_text"].append(f"{prefix}: {ocr_content}")
                dataset_dict["target_text"].append(
                    target_file.read_text(encoding="utf-8").strip()
                )

    return Dataset.from_dict(dataset_dict)

def main():
    """Fine-tune FLAN-T5 on the title/summary pairs and save the result."""
    # 1. Data preparation
    raw_dataset = load_data()
    if len(raw_dataset) == 0:
        print("❌ Nie znaleziono plików w content/titles/summary. Sprawdź ścieżki.")
        return

    dataset = raw_dataset.train_test_split(test_size=0.1)

    # 2. Tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    def preprocess(examples):
        # Tokenize inputs and targets, padding/truncating to fixed lengths.
        model_inputs = tokenizer(
            examples["input_text"],
            max_length=MAX_INPUT_LEN,
            truncation=True,
            padding="max_length",
        )
        labels = tokenizer(
            text_target=examples["target_text"],
            max_length=MAX_TARGET_LEN,
            truncation=True,
            padding="max_length",
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(preprocess, batched=True)

    # 3. Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="./tmp_results",
        eval_strategy="epoch",  # renamed from evaluation_strategy in recent transformers
        learning_rate=3e-4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=15,
        predict_with_generate=True,
        fp16=False,
        logging_steps=10,
        # Generation settings used during evaluation:
        generation_max_length=MAX_TARGET_LEN,
        generation_num_beams=4,
    )

    # 4. Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )

    print(f"🚀 Rozpoczynam uczenie na {len(raw_dataset)} przykładach...")
    trainer.train()

    # 5. Save the fine-tuned model and tokenizer side by side.
    OUTPUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_MODEL_DIR)
    tokenizer.save_pretrained(OUTPUT_MODEL_DIR)
    print(f"✨ Model wyuczony i zapisany w: {OUTPUT_MODEL_DIR}")

# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()