import json
import os
import random
import re
import shutil
import sys
import time
from collections import Counter
from datetime import datetime

import gradio as gr
import internetarchive
import nest_asyncio
import numpy as np
import pandas as pd
import requests
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

# --- SYSTEM FIXES ---
# Allow nested event loops; failure to patch is reported but not fatal.
try:
    nest_asyncio.apply()
except Exception as e:
    print(f"Warning: Could not apply nest_asyncio: {e}")

# --- CONFIGURATION ---
DATASET_DIR = "dataset_ml_final_v2"
BOOKS_DIR = os.path.join(DATASET_DIR, "books")
MODEL_DIR = "trained_models"
os.makedirs(MODEL_DIR, exist_ok=True)

# --- TOKENIZER & MODEL ---
# TOKENIZER is filled in below when the download succeeds; MODEL stays None until
# a model is trained or loaded.
TOKENIZER = None
MODEL = None

# Use the GPU when CUDA is available, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The library import and the tokenizer download are guarded separately so that a
# failed download (e.g. offline) does not throw away a successfully imported AdamW.
try:
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        get_linear_schedule_with_warmup,
        logging,
    )
    from torch.optim import AdamW

    logging.set_verbosity_error()
except Exception as e:
    print(f"⚠️ Tokenizer loading error: {e}")
    AdamW = None
else:
    try:
        print("Attempting to load Longformer Tokenizer...")
        TOKENIZER = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        print("✅ Tokenizer loaded successfully.")
    except Exception as e:
        print(f"⚠️ Tokenizer loading error: {e}")

# --- ERAS (10 Distinct Periods) ---
# Each entry: (start_year, end_year, era_label, search_hint). The search hint is
# OR-ed into the Internet Archive query to bias results toward the era.
ERAS = [
    (500, 1200, "0_Medieval", "Medieval OR Latin manuscript OR Anglo-Saxon prose"),
    (1200, 1470, "1_Late_Medieval", "Chaucer OR Middle English OR morality play"),
    (1470, 1650, "2_Early_Modern_Renaissance", "Shakespeare OR Bacon OR Protestant theology OR Early English Bible"),
    (1650, 1800, "3_Enlightenment_Classical", "Pope couplets OR Swift satire OR Neoclassical OR reason science"),
    (1800, 1850, "4_Romantic", "Byron OR Keats OR Shelley OR nature sublime emotion"),
    (1850, 1920, "5_Industrial_Victorian", "Dickens OR Industrial Age OR Darwinism OR social novel"),
    (1920, 1945, "6_Modernist", "Modernism OR stream of consciousness OR avant-garde fiction"),
    (1945, 1960, "7_Postwar_Early_Modern", "Postwar OR Early Cold War OR existentialism"),
    (1960, 1990, "8_Late_20th_Century", "Late 20th Century OR Postmodern OR Vietnam War"),
    (1990, 2024, "9_Contemporary_Information_Age", "Contemporary OR Digital era OR internet culture"),
]

ERA_LABELS = [era[2] for era in ERAS]
LABEL_TO_ID = {label: idx for idx, label in enumerate(ERA_LABELS)}
ID_TO_LABEL = {idx: label for idx, label in enumerate(ERA_LABELS)}

# --- RESCUE KEYWORDS ---
# Per-era keyword lists used for "rescue" searches that drop the date filter,
# for eras where date-filtered searches find too few texts.
RESCUE_KEYWORDS = {
    "0_Medieval": [
        "Beowulf", "Bede", "Anglo Saxon Chronicle", "Cynewulf", "Caedmon",
        "Old English Homilies", "Aelfric", "Boethius", "Alfred the Great",
        "Venerable Bede", "Old English", "Anglo-Saxon poetry",
    ],
    "1_Late_Medieval": [
        "Chaucer", "Canterbury Tales", "Piers Plowman", "Langland", "Gower",
        "Malory", "Morte d'Arthur", "Wycliffe", "Julian Norwich", "Margery Kempe",
        "Froissart", "Everyman", "Gawain", "Pearl Poet", "Lydgate",
        "Troilus Criseyde", "Book Duchess", "Parliament Fowls", "Legend Good Women",
        "Christine Pizan", "Romance Rose", "Confessio Amantis", "mystery plays",
        "miracle plays", "morality plays", "Middle English", "medieval romance",
        "medieval literature", "14th century literature", "15th century literature",
        "medieval poetry", "medieval drama", "Arthurian legend", "Chivalric romance",
        "Courtly love", "medieval manuscript", "Caxton", "medieval texts",
        "English medieval", "French medieval",
    ],
}

# Collections sampled by one of the Late Medieval rescue query shapes.
LATE_MEDIEVAL_COLLECTIONS = [
    "gutenberg", "opensource", "medievaltexts", "earlyenglishbooksonline",
    "englishliterature", "medievalmanuscripts", "britishlibrary",
]

# Subject topics rotated through for standard (date-filtered) searches; includes
# modern subjects so the contemporary eras have something to match.
TOPICS = [
    "History", "Philosophy", "Science", "Mathematics", "Medicine", "Astronomy",
    "Physics", "Chemistry", "Biology", "Fiction", "Poetry", "Drama", "Mythology",
    "Folklore", "Religion", "Theology", "Biography", "Politics", "Economics", "Law",
    "Sociology", "Technology", "Engineering", "Travel", "War", "Military", "Art",
    "Psychology", "Anthropology", "Literature", "Essays", "Memoirs", "Education",
    "Computer Programming", "Digital Culture", "Current Affairs",
]


# ============================================================================
# TAB 1: DATASET GENERATION
# ============================================================================

def setup_dirs():
    """Remove any previous dataset directory and recreate an empty ``BOOKS_DIR``.

    A failed removal is reported rather than silently ignored; files left over
    from an earlier run would otherwise end up in the new dataset archive.
    """
    if os.path.exists(DATASET_DIR):
        try:
            shutil.rmtree(DATASET_DIR)
        except OSError as e:
            print(f"⚠️ Could not fully remove old dataset directory '{DATASET_DIR}': {e}")
    os.makedirs(BOOKS_DIR, exist_ok=True)


def text_quality_check(text):
    """
    A light-weight quality check to filter out poor scan or boilerplate text.

    Returns False (reject) when the text:
      * is shorter than 3000 characters,
      * is less than 50% alphabetic characters,
      * mentions front/back-matter words ("table of contents", "preface", ...)
        in its first 1000 characters while being shorter than 10000 characters,
      * has fewer than 50 distinct non-blank lines, or
      * has more than 10% of its distinct lines occurring 3+ times
        (e.g. repeated page headers).
    Otherwise returns True.
    """
    if len(text) < 3000:
        return False

    alpha_count = sum(c.isalpha() for c in text)
    if alpha_count / (len(text) + 1e-6) < 0.5:
        return False

    start_snippet = text[:1000].lower()
    boilerplate_indicators = ["table of contents", "chapter i", "preface", "index", "list of figures"]
    if any(indicator in start_snippet for indicator in boilerplate_indicators) and len(text) < 10000:
        return False

    line_counts = Counter(line.strip() for line in text.split('\n') if line.strip())
    if len(line_counts) < 50:
        return False

    frequent_lines = sum(1 for count in line_counts.values() if count >= 3)
    if frequent_lines / len(line_counts) > 0.1:
        return False

    return True


def chunk_text_robust(text):
    """Split *text* into overlapping chunks suitable for Longformer training.

    With a tokenizer available, chunks are windows of 3500 tokens overlapping by
    500 tokens (decoded back to strings). If tokenization fails or no tokenizer
    is loaded, a word-based fallback produces windows of 2700 words overlapping
    by 400 words, dropping windows of 300 characters or fewer. At most 40 chunks
    are returned per book.
    """
    MAX_TOKENS = 3500
    STRIDE = 500
    MAX_CHUNKS_PER_BOOK = 40

    if TOKENIZER:
        try:
            token_ids = TOKENIZER.encode(text, add_special_tokens=False)
            token_chunks = []  # local list: a mid-way failure must not leak partial chunks
            start = 0
            while start < len(token_ids) and len(token_chunks) < MAX_CHUNKS_PER_BOOK:
                window = token_ids[start:start + MAX_TOKENS]
                token_chunks.append(TOKENIZER.decode(window, skip_special_tokens=True))
                start += MAX_TOKENS - STRIDE
            return token_chunks
        except Exception as e:
            print(f"⚠️ Token-based chunking failed, using word-based chunks: {e}")

    WORDS_PER_CHUNK = 2700
    WORD_STRIDE = 400
    chunks = []
    words = text.split()
    start = 0
    while start < len(words) and len(chunks) < MAX_CHUNKS_PER_BOOK:
        chunk_str = " ".join(words[start:start + WORDS_PER_CHUNK])
        if len(chunk_str) > 300:
            chunks.append(chunk_str)
        start += WORDS_PER_CHUNK - WORD_STRIDE
    return chunks
# --- TEXT CLEANUP & DOWNLOAD ---

def clean_text_content(text):
    """Strip Project Gutenberg style header/footer boilerplate from *text*.

    When a ``*** START OF`` marker is followed later by an ``*** END OF`` marker,
    only the text between them is kept, starting on the line after the start
    marker (the rest of that line is the marker banner). Otherwise the text is
    left as is. The result is always whitespace-stripped.
    """
    markers = [("*** START OF", "*** END OF")]
    for start_m, end_m in markers:
        s = text.find(start_m)
        if s == -1:
            continue
        # Only accept an end marker that comes after the start marker.
        e = text.find(end_m, s + len(start_m))
        if e == -1:
            continue
        body_start = text.find("\n", s, e)
        if body_start == -1:
            body_start = s + len(start_m)
        text = text[body_start:e]
        break
    return text.strip()


def download_book(identifier, title, year, era_label, min_char_limit=5000, bypass_quality_check=False):
    """Download an Internet Archive item's plain text and save it under ``BOOKS_DIR``.

    Tries the ``_djvu.txt`` rendition first, then ``.txt``.

    Args:
        identifier: Internet Archive item identifier.
        title: Item title (used in the file name and the returned record).
        year: Year string; written into the file name and record.
        era_label: Era label assigned to the book.
        min_char_limit: Minimum cleaned length (characters) required to keep it.
        bypass_quality_check: Skip ``text_quality_check`` when True.

    Returns:
        A metadata dict for the saved book, or None if nothing usable was found.
    """
    # NOTE(review): search metadata may be multi-valued; normalise defensively so
    # re.sub below cannot raise and abort the caller's whole search batch.
    if isinstance(title, (list, tuple)):
        title = title[0] if title else "Unknown"
    title = str(title)

    urls = [
        f"https://archive.org/download/{identifier}/{identifier}_djvu.txt",
        f"https://archive.org/download/{identifier}/{identifier}.txt",
    ]
    content = ""
    for url in urls:
        try:
            r = requests.get(url, timeout=15)
            if r.status_code == 200:
                content = r.text
                break
        except requests.RequestException:
            # Best effort: a timeout/connection error just moves on to the next URL.
            continue

    content = clean_text_content(content)
    if len(content) < min_char_limit:
        return None
    if not bypass_quality_check and not text_quality_check(content):
        return None

    safe_title = re.sub(r'[^a-zA-Z0-9]', '_', title)[:40]
    filename = f"{year}_{era_label}_{safe_title}_{identifier}.txt"
    with open(os.path.join(BOOKS_DIR, filename), "w", encoding="utf-8") as f:
        f.write(content)

    return {
        "title": title,
        "year": int(year),
        "era_label": era_label,
        "filename": filename,
        "char_count": len(content),
        "source": "Internet Archive",
    }


def _result_year(res):
    """Return the 4-digit year of a search hit as a string, or "0000" if unknown."""
    raw_date = res.get('date') or res.get('year')
    year = str(raw_date)[:4] if raw_date else "0000"
    # isdecimal (not isdigit) so that int(year) can never fail on e.g. superscripts.
    return year if year.isdecimal() else "0000"


def _year_in_era(year, start_year, end_year, is_last_era):
    """Check *year* against an era.

    Eras share boundary years (one era's end is the next one's start), so the
    check is half-open [start, end) to give each year exactly one era; the final
    era is closed so its end year is still accepted.
    """
    if is_last_era:
        return start_year <= year <= end_year
    return start_year <= year < end_year


def _fetch_search_batch(query, max_index):
    """Run an Internet Archive search (most downloaded first); return hits 0..max_index."""
    search_generator = internetarchive.search_items(
        query,
        sorts=['downloads desc'],
        fields=['identifier', 'title', 'date', 'year'],
    )
    batch = []
    for i, item in enumerate(search_generator):
        batch.append(item)
        if i >= max_index:
            break
    return batch


def _rescue_query(era_label, kw, attempt):
    """Build a date-free keyword query; the query shape rotates with *attempt*."""
    if era_label == "1_Late_Medieval":
        query_type = attempt % 6
        if query_type == 0:
            return f"title:({kw}) AND mediatype:texts"
        if query_type == 1:
            return f"({kw}) AND mediatype:texts AND language:eng"
        if query_type == 2:
            return f"subject:({kw}) AND mediatype:texts"
        if query_type == 3:
            col = random.choice(LATE_MEDIEVAL_COLLECTIONS)
            return f"({kw}) AND collection:({col}) AND mediatype:texts"
        if query_type == 4:
            return f"({kw}) AND date:[1200 TO 1900] AND mediatype:texts AND language:eng"
        return f"{kw} mediatype:texts"

    if attempt % 3 == 0:
        return f"title:({kw}) AND mediatype:texts"
    if attempt % 3 == 1:
        return f"({kw}) AND mediatype:texts AND language:eng"
    return f"subject:({kw}) AND mediatype:texts"


def _standard_query(topic, search_hint, start_year, end_year):
    """Build a topic + era-hint query restricted to the era's date range."""
    query = (
        f"(subject:{topic} OR {search_hint}) AND date:[{start_year} TO {end_year}] "
        "AND mediatype:texts AND language:eng"
    )
    # NOTE(review): presumably restricts post-1928 material to openly licensed items.
    if end_year > 1928:
        query += " AND (licenseurl:* OR rights:creative commons OR collection:opensourcemedia)"
    return query


def _try_add_book(res, year, era_label, topic, records, saved_ids, min_chars, bypass_qc):
    """Download one search hit; on success append its record, remember its id, and return it."""
    id_ = res.get('identifier')
    rec = download_book(
        id_,
        res.get('title', 'Unknown'),
        year,
        era_label,
        min_char_limit=min_chars,
        bypass_quality_check=bypass_qc,
    )
    if rec:
        rec['topic'] = topic
        records.append(rec)
        saved_ids.add(id_)
    return rec


_FALLBACK_TERMS = [
    "medieval english",
    "middle english",
    "chaucer OR malory OR gower",
    "14th century OR 15th century",
    "medieval literature english",
    "arthurian romance",
    "medieval poetry english",
]


def _emergency_fallback(era_label, books_per_era, collected, records, saved_ids, min_chars, bypass_qc):
    """Run broad, date-free searches for a badly under-filled era.

    Returns the updated number of books collected for the era.
    """
    print(f"\n⚠️ EMERGENCY FALLBACK MODE for {era_label}")
    for attempt, term in enumerate(_FALLBACK_TERMS, start=1):
        if collected >= books_per_era:
            break
        query = f"({term}) AND mediatype:texts"
        print(f"  > 🚨 Fallback #{attempt}: {term}")
        try:
            for res in _fetch_search_batch(query, 100):
                if collected >= books_per_era:
                    break
                id_ = res.get('identifier')
                if not id_ or id_ in saved_ids:
                    continue
                rec = _try_add_book(res, _result_year(res), era_label, "Medieval",
                                    records, saved_ids, min_chars, bypass_qc)
                if rec:
                    collected += 1
                    print(f"  ✅ FALLBACK Success ({collected}/{books_per_era}): {rec['title']} | Chars: {rec['char_count']}")
        except Exception as e:
            print(f"  ❌ Fallback error: {e}")
        time.sleep(1)
    return collected


def _chunk_records(records):
    """Chunk every saved book into training rows (one row per chunk)."""
    rows = []
    for r in records:
        file_path = os.path.join(BOOKS_DIR, r["filename"])
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                raw_text = f.read()
            chunks = chunk_text_robust(raw_text)
        except (OSError, ValueError) as e:
            print(f"  ❌ Error processing {r['filename']}: {e}")
            continue
        rows.extend(
            {
                "text": chunk,
                "era_label": r["era_label"],
                "year": r["year"],
                "chunk_id": idx,
                # Lets training keep every chunk of a book inside one split.
                "book_id": r["filename"],
            }
            for idx, chunk in enumerate(chunks)
        )
        print(f"  ✅ Chunked {r['title']}: {len(chunks)} chunks")
    return rows


def _dataset_stats(records, rows):
    """Return ``(general_stats_df, split_stats_df)`` summaries of the dataset."""
    df_rows = pd.DataFrame(rows)
    if not df_rows.empty:
        split_stats = df_rows['era_label'].value_counts().reset_index()
        split_stats.columns = ['Era Label', 'Total Chunks']
        split_stats['Est. Train (80%)'] = (split_stats['Total Chunks'] * 0.8).astype(int)
        split_stats['Est. Val (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Est. Test (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Status'] = split_stats['Est. Val (10%)'].apply(lambda x: "⚠️ LOW DATA" if x < 5 else "✅ OK")
    else:
        split_stats = pd.DataFrame()

    total_chunks = len(rows)
    avg_chunks = total_chunks / len(records) if records else 0
    general_stats_df = pd.DataFrame({
        "Metric": ["Total Books", "Total Training Examples", "Avg Examples/Book"],
        "Value": [len(records), total_chunks, f"{avg_chunks:.1f}"],
    })
    return general_stats_df, split_stats


def generate_dataset(total_books_needed, progress=gr.Progress()):
    """Scrape Internet Archive books for every era, chunk them and package a dataset.

    Args:
        total_books_needed: Total books to aim for, split evenly across ``ERAS``
            (at least one per era).
        progress: Gradio progress tracker.

    Returns:
        ``(zip_path, general_stats_df, split_stats_df)``, or
        ``(None, empty_df, empty_df)`` when no book could be collected.
    """
    setup_dirs()
    records = []
    saved_ids = set()  # identifiers already saved (O(1) duplicate check)
    books_per_era = max(1, int(total_books_needed / len(ERAS)))
    last_era_end = ERAS[-1][1]

    for era_idx, (start_year, end_year, era_label, search_hint) in enumerate(ERAS):
        collected = 0
        attempts = 0
        era_topics = TOPICS.copy()
        random.shuffle(era_topics)

        rescue_list = RESCUE_KEYWORDS.get(era_label, [])
        is_hard_era = len(rescue_list) > 0
        is_contemporary = era_label == "9_Contemporary_Information_Age"

        min_chars = 5000
        bypass_qc = False
        if is_contemporary:
            # Scarce open post-1990 text: accept shorter books, skip the strict check.
            min_chars = 2000
            bypass_qc = True
            max_attempts = 40
        elif is_hard_era:
            min_chars = 1000
            max_attempts = 50 if era_label == "1_Late_Medieval" else 20
        else:
            max_attempts = 20
        rescue_threshold = 0 if era_label == "1_Late_Medieval" else 3
        max_check_per_query = 50 if is_hard_era or is_contemporary else 15

        progress(era_idx / len(ERAS) * 0.9, desc=f"Scraping Era: {era_label}")
        print(f"\n{'='*60}")
        print(f"Starting Era: {era_label} (Target: {books_per_era} books | Min Chars: {min_chars})")
        print(f"{'='*60}")

        while collected < books_per_era and attempts < max_attempts:
            attempts += 1
            using_rescue = is_hard_era and attempts > rescue_threshold

            if using_rescue:
                kw = random.choice(rescue_list)
                query = _rescue_query(era_label, kw, attempts)
                topic = "Classic"
                print(f"  > 🛡️ Rescue Search #{attempts} ({era_label}): {kw}")
            else:
                if not era_topics:
                    era_topics = TOPICS.copy()
                    random.shuffle(era_topics)
                topic = era_topics.pop()
                query = _standard_query(topic, search_hint, start_year, end_year)
                print(f"  > Standard Search #{attempts}: {topic} | Hint: {search_hint.split(' OR ')[0]}...")

            try:
                batch = _fetch_search_batch(query, max_check_per_query)
                for res in batch:
                    if collected >= books_per_era:
                        break
                    id_ = res.get('identifier')
                    if not id_ or id_ in saved_ids:
                        continue
                    year = _result_year(res)
                    # Rescue hits are accepted regardless of date (e.g. later editions of classics).
                    if not using_rescue and not _year_in_era(int(year), start_year, end_year,
                                                             end_year == last_era_end):
                        continue
                    rec = _try_add_book(res, year, era_label, topic, records, saved_ids,
                                        min_chars, bypass_qc)
                    if rec:
                        collected += 1
                        print(f"  ✅ Saved ({collected}/{books_per_era}): {rec['title']} ({year}) | Chars: {rec['char_count']}")
                if not batch:
                    print("  ⚠️ No results found for this query")
            except Exception as e:
                print(f"  ❌ Search error: {e}")
            time.sleep(1)

        print(f"Completed {era_label}: {collected}/{books_per_era} books collected")

        if era_label == "1_Late_Medieval" and collected < books_per_era * 0.3:
            collected = _emergency_fallback(era_label, books_per_era, collected, records,
                                            saved_ids, min_chars, bypass_qc)

    if not records:
        return None, pd.DataFrame(), pd.DataFrame()

    print("\n" + "="*60)
    print("Starting Robust Chunking...")
    print("="*60)
    progress(0.9, desc="Chunking Text...")

    longformer_rows = _chunk_records(records)
    general_stats_df, split_stats = _dataset_stats(records, longformer_rows)

    pd.DataFrame(records).to_csv(os.path.join(DATASET_DIR, "metadata.csv"), index=False)
    jsonl_path = os.path.join(DATASET_DIR, "longformer_dataset.jsonl")
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for row in longformer_rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    timestamp = int(datetime.now().timestamp())
    zip_filename = f"Analyzed_ML_Dataset_{timestamp}"
    shutil.make_archive(zip_filename, 'zip', DATASET_DIR)

    print("\n" + "="*60)
    print("Dataset Generation Complete! READY FOR RETRAINING.")
    print("="*60)
    return f"{zip_filename}.zip", general_stats_df, split_stats


# ============================================================================
# TAB 2: TRAINING
# ============================================================================

# Absolute path of the checkpoint currently held in MODEL (None if the in-memory
# model was never saved/loaded). Lets predict_era honour an explicit path.
_LOADED_MODEL_PATH = None

# Below this many distinct books, a per-book split is too coarse; split per chunk.
_MIN_BOOKS_FOR_GROUP_SPLIT = 10


class LongformerDataset(Dataset):
    """Torch dataset that tokenizes each text on access, padded/truncated to *max_length*."""

    def __init__(self, texts, labels, tokenizer, max_length=4096):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }


def _safe_split(*arrays, test_size, stratify):
    """``train_test_split`` (seed 42) that falls back to an unstratified split.

    sklearn refuses to stratify when a class has fewer than two members or the
    test part is smaller than the number of classes, which is common for scarce eras.
    """
    try:
        return train_test_split(*arrays, test_size=test_size, random_state=42, stratify=stratify)
    except ValueError:
        return train_test_split(*arrays, test_size=test_size, random_state=42)


def _split_dataset(df):
    """Split dataset rows ~80/10/10 into ``[(X_train, y_train), (X_val, y_val), (X_test, y_test)]``.

    When every row has a ``book_id`` and there are enough books, the split is
    made per book so overlapping chunks of one book never appear in more than
    one split (which would inflate validation/test accuracy). Datasets without
    ``book_id`` fall back to a per-chunk split.
    """
    texts = df['text'].tolist()
    labels = [LABEL_TO_ID[label] for label in df['era_label'].tolist()]

    if ('book_id' in df.columns and df['book_id'].notna().all()
            and df['book_id'].nunique() >= _MIN_BOOKS_FOR_GROUP_SPLIT):
        book_ids = df['book_id'].tolist()
        book_label = {}
        for book, label in zip(book_ids, labels):
            book_label.setdefault(book, label)  # every chunk of a book shares one era
        books = sorted(book_label)
        book_labels = [book_label[b] for b in books]
        try:
            train_b, temp_b, _, temp_y = _safe_split(books, book_labels, test_size=0.2, stratify=book_labels)
            val_b, test_b = _safe_split(temp_b, test_size=0.5, stratify=temp_y)
        except ValueError:
            pass  # too few books even for a random split; use the per-chunk split
        else:
            split_of = {}
            for split_idx, split_books in enumerate((train_b, val_b, test_b)):
                for b in split_books:
                    split_of[b] = split_idx
            parts = [([], []), ([], []), ([], [])]
            for text, label, book in zip(texts, labels, book_ids):
                part_texts, part_labels = parts[split_of[book]]
                part_texts.append(text)
                part_labels.append(label)
            return parts

    X_train, X_temp, y_train, y_temp = _safe_split(texts, labels, test_size=0.2, stratify=labels)
    X_val, X_test, y_val, y_test = _safe_split(X_temp, y_temp, test_size=0.5, stratify=y_temp)
    return [(X_train, y_train), (X_val, y_val), (X_test, y_test)]


def _evaluate(model, loader):
    """Return the classification accuracy of *model* on *loader* (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
        for batch in loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def train_model(dataset_path, epochs, batch_size, learning_rate, gradient_accumulation_steps, progress=gr.Progress()):
    """Fine-tune Longformer on a generated JSONL dataset and save the checkpoint.

    Args:
        dataset_path: JSONL file written by ``generate_dataset``.
        epochs: Passes over the training split.
        batch_size: Per-step batch size (memory control).
        learning_rate: AdamW learning rate.
        gradient_accumulation_steps: Batches accumulated per optimizer step.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, metrics_df, model_path)``; the last two are None on failure.
    """
    global MODEL, TOKENIZER, _LOADED_MODEL_PATH

    if not TOKENIZER:
        return "❌ Tokenizer not loaded. Please install transformers library.", None, None
    if not os.path.exists(dataset_path):
        return "❌ Dataset file not found. Please generate a dataset first.", None, None

    # UI widgets may deliver floats (or None when cleared); range()/DataLoader need ints.
    try:
        epochs = int(epochs)
        batch_size = int(batch_size)
        gradient_accumulation_steps = int(gradient_accumulation_steps)
        learning_rate = float(learning_rate)
    except (TypeError, ValueError):
        return "❌ Error: Epochs, Batch Size, Learning Rate and Gradient Accumulation Steps must be numbers.", None, None

    if epochs < 1:
        return "❌ Error: Epochs must be at least 1.", None, None
    if batch_size < 1:
        return "❌ Error: Batch Size must be at least 1.", None, None
    if gradient_accumulation_steps < 1:
        return "❌ Error: Gradient Accumulation Steps must be at least 1.", None, None

    use_amp = DEVICE == "cuda"
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    try:
        progress(0.1, desc="Loading dataset...")
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]
        if not data:
            return "❌ Dataset file is empty. Please generate a dataset first.", None, None
        df = pd.DataFrame(data)

        progress(0.2, desc="Splitting data...")
        (X_train, y_train), (X_val, y_val), (X_test, y_test) = _split_dataset(df)

        train_loader = DataLoader(LongformerDataset(X_train, y_train, TOKENIZER), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(LongformerDataset(X_val, y_val, TOKENIZER), batch_size=batch_size)
        test_loader = DataLoader(LongformerDataset(X_test, y_test, TOKENIZER), batch_size=batch_size)

        progress(0.3, desc="Initializing model...")
        MODEL = AutoModelForSequenceClassification.from_pretrained(
            "allenai/longformer-base-4096",
            num_labels=len(LABEL_TO_ID),
            id2label=ID_TO_LABEL,
            label2id=LABEL_TO_ID,
        )
        MODEL.to(DEVICE)
        _LOADED_MODEL_PATH = None  # the fresh in-memory model matches no saved checkpoint yet

        optimizer = AdamW(MODEL.parameters(), lr=learning_rate)
        total_batches = len(train_loader)
        # One optimizer step per full accumulation group plus one for a trailing
        # partial group (ceil division); max(1, ...) keeps the LR schedule non-zero.
        steps_per_epoch = (total_batches + gradient_accumulation_steps - 1) // gradient_accumulation_steps
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=max(1, steps_per_epoch * epochs)
        )

        train_losses = []
        val_accuracies = []
        for epoch in range(epochs):
            MODEL.train()
            optimizer.zero_grad()
            total_loss = 0.0

            for batch_idx, batch in enumerate(train_loader):
                progress_val = 0.3 + ((epoch + batch_idx / total_batches) / epochs) * 0.6
                progress(progress_val, desc=f"Training Epoch {epoch+1}/{epochs} (Batch {batch_idx+1}/{total_batches})")

                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)

                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss / gradient_accumulation_steps

                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                total_loss += loss.item() * gradient_accumulation_steps

                # Accumulation groups are counted per epoch so every epoch starts a fresh group.
                end_of_group = (batch_idx + 1) % gradient_accumulation_steps == 0
                if end_of_group or batch_idx == total_batches - 1:
                    if scaler is not None:
                        scaler.unscale_(optimizer)  # clip on real (unscaled) gradients
                        torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                        optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

            avg_loss = total_loss / max(1, total_batches)
            val_acc = _evaluate(MODEL, val_loader)
            train_losses.append(avg_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}")

        test_acc = _evaluate(MODEL, test_loader)

        progress(0.95, desc="Saving model...")
        timestamp = int(datetime.now().timestamp())
        model_path = os.path.join(MODEL_DIR, f"longformer_era_classifier_{timestamp}")
        MODEL.save_pretrained(model_path)
        TOKENIZER.save_pretrained(model_path)
        _LOADED_MODEL_PATH = os.path.abspath(model_path)

        metrics_df = pd.DataFrame({
            "Epoch": list(range(1, epochs + 1)),
            "Training Loss": train_losses,
            "Validation Accuracy": [f"{acc:.4f}" for acc in val_accuracies],
        })
        summary = (
            f"✅ Training Complete!\nFinal Val Acc: {val_accuracies[-1]:.4f}\n"
            f"Test Acc: {test_acc:.4f}\nModel saved to: {model_path}"
        )
        return summary, metrics_df, model_path

    except RuntimeError as e:
        if 'out of memory' in str(e):
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            return f"❌ Training error: CUDA Out Of Memory. Try reducing the 'Batch Size' slider to 1, or increase 'Gradient Accumulation Steps'. Error: {str(e)}", None, None
        return f"❌ Training error: {str(e)}", None, None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None, None


# ============================================================================
# TAB 3: TESTING
# ============================================================================

def load_trained_model(model_path):
    """Load a saved tokenizer + model into the globals; return a status message.

    Both are loaded before either global is replaced, so a failure leaves the
    previously loaded pair intact instead of a mismatched tokenizer/model.
    """
    global MODEL, TOKENIZER, _LOADED_MODEL_PATH
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(model_path)
        new_model = AutoModelForSequenceClassification.from_pretrained(model_path)
        new_model.to(DEVICE)
        new_model.eval()
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
    TOKENIZER = new_tokenizer
    MODEL = new_model
    _LOADED_MODEL_PATH = os.path.abspath(model_path)
    return f"✅ Model loaded successfully from {model_path}"


def predict_era(text, model_path):
    """Classify *text* into an era.

    An existing *model_path* that differs from the model currently in memory is
    loaded first; with an empty/invalid path the in-memory model is used.

    Returns:
        ``(markdown_result, top3_dataframe)``; the dataframe is None on error.
    """
    global MODEL, TOKENIZER

    if not text or not text.strip():
        return "❌ Please enter some text to classify.", None

    model_path = (model_path or "").strip()
    requested = os.path.abspath(model_path) if model_path and os.path.exists(model_path) else None
    if requested and (MODEL is None or requested != _LOADED_MODEL_PATH):
        load_result = load_trained_model(model_path)
        if "Error" in load_result:
            return load_result, None
    elif MODEL is None or not TOKENIZER:
        return "❌ No model loaded. Please train a model first or provide a valid model path.", None

    try:
        encoding = TOKENIZER(
            text,
            add_special_tokens=True,
            max_length=4096,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(DEVICE)
        attention_mask = encoding['attention_mask'].to(DEVICE)

        MODEL.eval()
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
        # Softmax in float32 so FP16 logits do not lose precision in the probabilities.
        probabilities = torch.softmax(outputs.logits.float(), dim=1)[0]
        predicted_class = torch.argmax(probabilities).item()

        top_k = min(3, probabilities.shape[0])
        top_probs, top_indices = torch.topk(probabilities, top_k)
        results = [
            {"Era": ID_TO_LABEL[idx.item()], "Confidence": f"{prob.item() * 100:.2f}%"}
            for idx, prob in zip(top_indices, top_probs)
        ]

        predicted_era = ID_TO_LABEL[predicted_class]
        result_text = f"🎯 **Predicted Era:** {predicted_era}\n\n**Confidence:** {probabilities[predicted_class].item()*100:.2f}%"
        return result_text, pd.DataFrame(results)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}", None


# ============================================================================
# GRADIO UI
# ============================================================================

with gr.Blocks(title="Complete ML Pipeline") as demo:
    gr.Markdown("# 📚 Complete ML Pipeline: Dataset Generation, Training & Testing (RTX 4080 Super Optimized)")

    with gr.Tabs():
        # TAB 1: Dataset Generation
        with gr.Tab("📊 Dataset Generation"):
            gr.Markdown("## Generate Historical Text Dataset")
            gr.Markdown("""
            **DATA QUALITY FIX:** Contemporary Era (`9_...`) now has lower length requirements and a less strict quality check to compensate for scarce open-source post-1990 data.
            """)
            with gr.Row():
                dataset_slider = gr.Slider(10, 500, step=10, value=100, label="Total Books to Collect (Max 500)")
            generate_btn = gr.Button("🚀 Generate Dataset (New Data Quality)", variant="primary", size="lg")
            dataset_download = gr.File(label="📥 Download Dataset ZIP")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### General Summary")
                    gen_stats = gr.Dataframe()
                with gr.Column():
                    gr.Markdown("### Class Balance Check")
                    split_stats = gr.Dataframe()
            generate_btn.click(
                generate_dataset,
                inputs=[dataset_slider],
                outputs=[dataset_download, gen_stats, split_stats],
            )

        # TAB 2: Training
        with gr.Tab("🎓 Model Training"):
            gr.Markdown("## Train Longformer Era Classifier")
            gr.Markdown("""
            **GPU OPTIMIZED:** Training now uses **Automatic Mixed Precision (FP16/AMP)** for the RTX 4080 Super.
            With 16GB VRAM, you can use a higher **Batch Size** (e.g., 4 or 8) and often set **Gradient Accumulation Steps** to 1.
            """)
            with gr.Row():
                with gr.Column():
                    train_dataset_path = gr.Textbox(
                        label="Dataset Path",
                        value=os.path.join(DATASET_DIR, "longformer_dataset.jsonl"),
                        placeholder="Path to dataset JSONL file",
                    )
                    train_epochs = gr.Slider(1, 10, step=1, value=3, label="Epochs")
                    train_batch = gr.Slider(1, 16, step=1, value=4, label="Batch Size (Memory Control)")
                    train_accum = gr.Slider(1, 16, step=1, value=1, label="Gradient Accumulation Steps (Effective Batch Size)")
                    train_lr = gr.Number(value=2e-5, label="Learning Rate")
                    train_btn = gr.Button("🏋️ Start Training", variant="primary", size="lg")
                with gr.Column():
                    train_output = gr.Textbox(label="Training Status", lines=8)
                    train_metrics = gr.Dataframe(label="Training Metrics")
                    model_path_output = gr.Textbox(label="Saved Model Path")
            train_btn.click(
                train_model,
                inputs=[train_dataset_path, train_epochs, train_batch, train_lr, train_accum],
                outputs=[train_output, train_metrics, model_path_output],
            )

        # TAB 3: Testing
        with gr.Tab("🧪 Model Testing"):
            gr.Markdown("## Test Era Classification (FP16/AMP Inference)")
            with gr.Row():
                with gr.Column():
                    test_model_path = gr.Textbox(
                        label="Model Path (optional - uses last trained model if empty)",
                        placeholder="trained_models/longformer_era_classifier_...",
                    )
                    test_input = gr.Textbox(
                        label="Input Text",
                        lines=10,
                        placeholder="Paste historical text here...\n\nExample: 'When that Aprille with his shoures soote, The droghte of Marche hath perced to the roote...'",
                    )
                    test_btn = gr.Button("🔍 Predict Era", variant="primary", size="lg")
                with gr.Column():
                    test_result = gr.Markdown(label="Prediction Result")
                    test_probabilities = gr.Dataframe(label="Top 3 Predictions")

            # Sample texts
            gr.Markdown("### 📝 Try Sample Texts")
            with gr.Row():
                sample1 = gr.Button("Medieval Sample")
                sample2 = gr.Button("Victorian Sample")
                sample3 = gr.Button("Contemporary Sample")

            def load_medieval():
                # Opening of Beowulf (Old English).
                return "Hwæt! We Gardena in geardagum, þeodcyninga, þrym gefrunon, hu ða æþelingas ellen fremedon."

            def load_victorian():
                # Opening of Dickens' "A Tale of Two Cities" (1859), inside the 1850-1920
                # "5_Industrial_Victorian" era (the previous Austen 1813 sample is Romantic-era).
                return "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity."

            def load_contemporary():
                return "The internet has fundamentally transformed how we communicate, work, and access information in the digital age."

            sample1.click(load_medieval, outputs=[test_input])
            sample2.click(load_victorian, outputs=[test_input])
            sample3.click(load_contemporary, outputs=[test_input])

            test_btn.click(
                predict_era,
                inputs=[test_input, test_model_path],
                outputs=[test_result, test_probabilities],
            )

    gr.Markdown("---")
    gr.Markdown(f"**Device:** {DEVICE} | **Status:** {'✅ CUDA/FP16 Ready' if DEVICE == 'cuda' else '⚠️ CPU Mode'} | **Model:** Longformer-base-4096")


if __name__ == "__main__":
    demo.launch(ssr_mode=False)