import os
import gc
from itertools import islice
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # headless backend: the app runs on a server without a display
import matplotlib.pyplot as plt
from datasets import load_dataset
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image

from train_tokenizer import train_tokenizer

# Seed langdetect so language detection is deterministic across runs.
DetectorFactory.seed = 0

CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "./tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5_000_000       # default cap on collected samples
DEFAULT_CHUNK_SIZE = 200_000  # samples flushed to the checkpoint per chunk
BATCH_SIZE = 1000             # examples pulled from the stream per batch
NUM_WORKERS = 4               # threads used for language filtering
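
# Pipeline overview: stream the configured dataset(s), keep only Greek/English
# text, append accepted samples to checkpoint.txt in chunks, then train the
# tokenizer (see train_tokenizer.py) on the checkpointed corpus.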

# Cooperative stop flag: set by the Stop button, polled by collect_samples().
STOP_COLLECTION = False


def load_checkpoint():
    """Load previously collected texts from the checkpoint file."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []


def append_to_checkpoint(texts):
    """Append a batch of texts to the checkpoint file, one sample per line."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        f.write("\n".join(texts) + "\n")
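
# Checkpoint format: one raw text sample per line, so progress is a simple
# line count and collection can resume across restarts.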


def create_iterator(dataset_name, configs, split):
    """Stream each dataset config in batches and filter texts in parallel."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(
                dataset_name,
                name=config,
                split=split,
                streaming=True,
                cache_dir="./dataset_cache"
            )
            # Consume the stream once via islice; chained take()/skip() calls
            # would re-read every previously skipped example on each batch.
            stream = iter(dataset)
            while True:
                batch = list(islice(stream, BATCH_SIZE))
                if not batch:
                    break
                with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
                    processed_texts = list(executor.map(process_example, batch))
                yield from filter(None, processed_texts)
        except Exception as e:
            print(f"⚠️ Error loading {config}: {e}")


def process_example(example):
    """Return the example's text if it is detected as Greek or English, else None."""
    try:
        text = example.get("text", "").strip()
        if text and detect(text) in ("el", "en"):
            return text
        return None
    except Exception:
        # langdetect raises on empty or undetectable input; drop such samples.
        return None
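
# Note: langdetect is unreliable on very short strings, so a few misdetected
# samples slipping through (or being dropped) here is expected.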


def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
    """Stream samples and persist them to the checkpoint in chunks."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())  # resume the count from previous runs
    progress_messages = [f"🚀 Starting collection... Progress: {total_processed}/{max_samples}"]
    dataset_iterator = create_iterator(dataset_name, configs, split)
    chunk = []
    while not STOP_COLLECTION and total_processed < max_samples:
        try:
            while len(chunk) < chunk_size and not STOP_COLLECTION:
                text = next(dataset_iterator)
                if text:
                    chunk.append(text)
                    total_processed += 1
                    if total_processed >= max_samples:
                        break
            if chunk:
                append_to_checkpoint(chunk)
                progress_messages.append(f"✅ Saved {len(chunk)} samples (total: {total_processed})")
                chunk = []
                gc.collect()
        except StopIteration:
            if chunk:
                # Flush the final partial chunk before finishing.
                append_to_checkpoint(chunk)
                progress_messages.append(f"✅ Saved {len(chunk)} samples (total: {total_processed})")
            progress_messages.append("🏁 Finished processing all available data!")
            break
        except Exception as e:
            progress_messages.append(f"⛔ Error: {e}")
            break
    return "\n".join(progress_messages)
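
# The Stop button works because Gradio runs event handlers in the same process,
# so stop_collection() (defined below) can flip the module-level flag this loop polls.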


def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Train the tokenizer on the checkpointed texts and run a quick quality check."""
    messages = ["🚀 Starting training..."]
    try:
        all_texts = load_checkpoint()
        if not all_texts:
            messages.append("❌ No checkpoint data found; collect samples first.")
            return ("\n".join(messages), "", None)
        messages.append("📚 Loaded data from checkpoint...")
        train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
        messages.append("✅ Training complete!")
        # Round-trip the test text and plot the distribution of token lengths.
        trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
        encoded = trained_tokenizer.encode(test_text)
        decoded = trained_tokenizer.decode(encoded.ids)
        fig, ax = plt.subplots()
        ax.hist([len(t) for t in encoded.tokens], bins=20)
        ax.set_xlabel("Token length (characters)")
        ax.set_ylabel("Frequency")
        img_buffer = BytesIO()
        fig.savefig(img_buffer, format="png")
        plt.close(fig)
        img_buffer.seek(0)  # rewind so PIL reads the image from the start
        return ("\n".join(messages), decoded, Image.open(img_buffer))
    except Exception as e:
        messages.append(f"❌ Error: {e}")
        return ("\n".join(messages), "", None)
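
# Reading the histogram: many single-character tokens on the test sentence
# suggest the vocabulary or training corpus is too small for the target language.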


def analyze_checkpoint():
    """Analyze the checkpointed data: sample count, average length, language mix."""
    messages = ["🔍 Starting analysis..."]
    try:
        texts = load_checkpoint()
        if not texts:
            return "No data found to analyze."
        total_chars = sum(len(t) for t in texts)
        avg_length = total_chars / len(texts)
        languages = {}
        for t in texts[:1000]:
            if len(t) > 20:  # very short strings make detection unreliable
                try:
                    lang = detect(t)
                    languages[lang] = languages.get(lang, 0) + 1
                except Exception as e:
                    print(f"⚠️ Language detection error: {e}")
        detected = sum(languages.values()) or 1
        report = [
            f"📊 Total samples: {len(texts)}",
            f"📝 Average length: {avg_length:.1f} characters",
            "🌍 Languages (first 1000 samples):",
            *[f"- {k}: {v} ({100 * v / detected:.1f}%)" for k, v in languages.items()]
        ]
        return "\n".join(messages + report)
    except Exception as e:
        messages.append(f"❌ Error: {e}")
        return "\n".join(messages)


def restart_collection():
    """Delete the checkpoint and reset for a fresh collection run."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    return "🔄 Checkpoint deleted. Ready for a new collection run."
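

def stop_collection():
    """Signal collect_samples() to stop; the flag is checked on every loop pass."""
    global STOP_COLLECTION
    STOP_COLLECTION = True
    return "⏹️ Stopping collection..."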


with gr.Blocks() as demo:
    gr.Markdown("## Custom Tokenizer Trainer for GPT-2")
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
            split = gr.Dropdown(["train"], value="train", label="Split")
            chunk_size = gr.Slider(10000, 500000, value=DEFAULT_CHUNK_SIZE, step=10000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 10, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            max_samples = gr.Slider(10000, 10000000, value=MAX_SAMPLES, step=100000, label="Max Samples")
            with gr.Row():
                start_btn = gr.Button("Start", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                restart_btn = gr.Button("Restart")
            analyze_btn = gr.Button("Analyze Data")
            train_btn = gr.Button("Train Tokenizer", variant="primary")
        with gr.Column(scale=3):
            progress = gr.Textbox(label="Progress", lines=10, interactive=False)
            gr.Markdown("### Results")
            decoded_text = gr.Textbox(label="Decoded Text")
            token_distribution = gr.Image(label="Token Distribution")

    # Wire the UI events to the handlers above.
    start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
    stop_btn.click(stop_collection, None, progress, queue=False)
    restart_btn.click(restart_collection, None, progress)
    analyze_btn.click(analyze_checkpoint, None, progress)
    train_btn.click(
        train_tokenizer_fn,
        [dataset_name, configs, split, vocab_size, min_freq, test_text],
        [progress, decoded_text, token_distribution],
    )


demo.queue().launch()