# NOTE: the "Spaces: Sleeping" header that appeared here was Hugging Face
# Spaces page chrome captured when the file was copied from the web UI;
# it is not part of the program and has been removed.
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| import internetarchive | |
| from datetime import datetime | |
| import re | |
| import os | |
| import shutil | |
| import time | |
| import random | |
| import json | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from sklearn.model_selection import train_test_split | |
| import numpy as np | |
| import nest_asyncio | |
| import sys | |
| # --- SYSTEM FIXES --- | |
| try: | |
| nest_asyncio.apply() | |
| except Exception as e: | |
| print(f"Warning: Could not apply nest_asyncio: {e}") | |
# --- CONFIGURATION ---
DATASET_DIR = "dataset_ml_final_v2"              # root folder for one dataset-generation run
BOOKS_DIR = os.path.join(DATASET_DIR, "books")   # raw downloaded book texts
MODEL_DIR = "trained_models"                     # fine-tuned checkpoint directories
os.makedirs(MODEL_DIR, exist_ok=True)
# --- TOKENIZER & MODEL ---
TOKENIZER = None  # set below if transformers loads; also replaced by load_trained_model
MODEL = None      # set by train_model / load_trained_model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # Import transformers inside a try so the UI can still start (with a
    # warning banner) when the library is missing or the hub download fails.
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, logging
    from torch.optim import AdamW
    logging.set_verbosity_error()
    print("Attempting to load Longformer Tokenizer...")
    TOKENIZER = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    print("✅ Tokenizer loaded successfully.")
except Exception as e:
    print(f"⚠️ Tokenizer loading error: {e}")
    # AdamW is referenced by train_model; keep the name defined on failure.
    AdamW = None
# --- ERAS (10 Distinct Periods) ---
# Each entry: (start_year, end_year, class_label, Internet Archive search hint).
# The class_label's numeric prefix doubles as the class id ordering.
ERAS = [
    (500, 1200, "0_Medieval", "Medieval OR Middle Ages OR Latin manuscripts"),
    (1200, 1470, "1_Late_Medieval", "Middle English OR Old French OR Troubadour"),
    (1470, 1650, "2_Early_Modern_Renaissance", "Renaissance OR Early Modern"),
    (1650, 1800, "3_Enlightenment_Classical", "Enlightenment OR Classical literature"),
    (1800, 1850, "4_Romantic", "Romanticism OR Romantic period"),
    (1850, 1920, "5_Industrial_Victorian", "Victorian OR Industrial Age"),
    (1920, 1945, "6_Modernist", "Modernism OR Avant-garde"),
    (1945, 1960, "7_Postwar_Early_Modern", "Postwar OR Early Cold War"),
    (1960, 1990, "8_Late_20th_Century", "Late 20th Century OR Postmodern"),
    (1990, 2024, "9_Contemporary_Information_Age", "Contemporary OR Digital era")
]
ERA_LABELS = [era[2] for era in ERAS]  # class names; list index == class id
LABEL_TO_ID = {label: idx for idx, label in enumerate(ERA_LABELS)}
ID_TO_LABEL = {idx: label for idx, label in enumerate(ERA_LABELS)}
# --- RESCUE KEYWORDS ---
# Hand-picked author/work keywords used when generic era queries return too few
# usable texts. An era with an entry here is treated as a "hard" era by
# generate_dataset (lower size threshold, more attempts, rescue queries).
RESCUE_KEYWORDS = {
    "0_Medieval": [
        "Beowulf", "Bede", "Anglo Saxon Chronicle", "Cynewulf", "Caedmon",
        "Old English Homilies", "Aelfric", "Boethius", "Alfred the Great",
        "Venerable Bede", "Old English", "Anglo-Saxon poetry"
    ],
    "1_Late_Medieval": [
        "Chaucer", "Canterbury Tales", "Piers Plowman", "Langland",
        "Gower", "Malory", "Morte d'Arthur", "Wycliffe",
        "Julian Norwich", "Margery Kempe", "Froissart", "Everyman",
        "Gawain", "Pearl Poet", "Lydgate", "Troilus Criseyde",
        "Book Duchess", "Parliament Fowls", "Legend Good Women",
        "Christine Pizan", "Romance Rose", "Confessio Amantis",
        "mystery plays", "miracle plays", "morality plays",
        "Middle English", "medieval romance", "medieval literature",
        "14th century literature", "15th century literature",
        "medieval poetry", "medieval drama", "Arthurian legend",
        "Chivalric romance", "Courtly love", "medieval manuscript",
        "Caxton", "medieval texts", "English medieval", "French medieval"
    ]
}
# archive.org collections targeted by one of the Late-Medieval query shapes.
LATE_MEDIEVAL_COLLECTIONS = [
    "gutenberg", "opensource", "medievaltexts", "earlyenglishbooksonline",
    "englishliterature", "medievalmanuscripts", "britishlibrary"
]
# Generic subject terms cycled through (shuffled) for standard era searches.
TOPICS = [
    "History", "Philosophy", "Science", "Mathematics", "Medicine", "Astronomy",
    "Physics", "Chemistry", "Biology", "Fiction", "Poetry", "Drama", "Mythology",
    "Folklore", "Religion", "Theology", "Biography", "Politics", "Economics", "Law",
    "Sociology", "Technology", "Engineering", "Travel", "War", "Military", "Art",
    "Psychology", "Anthropology", "Literature", "Essays", "Memoirs", "Education"
]
| # ============================================================================ | |
| # TAB 1: DATASET GENERATION | |
| # ============================================================================ | |
def setup_dirs():
    """Reset the dataset directory tree for a fresh generation run.

    Removes any previous run's output and recreates the books folder so
    downloads always start from a clean slate.
    """
    # ignore_errors=True replaces the old bare try/except: a failed or
    # unnecessary removal (e.g. the directory does not exist yet) should
    # never abort dataset generation, but real bugs are no longer hidden
    # by a blanket `except: pass`.
    shutil.rmtree(DATASET_DIR, ignore_errors=True)
    os.makedirs(BOOKS_DIR, exist_ok=True)
def chunk_text_robust(text, tokenizer=None, max_tokens=3500, stride=500, max_chunks=40):
    """Split *text* into overlapping chunks sized for Longformer training.

    Prefers token-based chunking when a tokenizer is available (the module's
    global TOKENIZER by default); falls back to word-based chunking otherwise
    or when tokenization fails.

    Args:
        text: raw book text.
        tokenizer: optional HF-style tokenizer with encode/decode; defaults to
            the module-level TOKENIZER (backward compatible with old calls).
        max_tokens: tokens per chunk in the token-based path.
        stride: overlap (in tokens) between consecutive chunks.
        max_chunks: hard cap on chunks per book.

    Returns:
        List of chunk strings (possibly empty for very short text).
    """
    # globals().get keeps the old implicit-global behavior without raising
    # NameError when the module-level TOKENIZER is absent (e.g. in tests).
    tok = tokenizer if tokenizer is not None else globals().get("TOKENIZER")
    chunks = []
    if tok:
        try:
            tokens = tok.encode(text, add_special_tokens=False)
            i = 0
            while i < len(tokens) and len(chunks) < max_chunks:
                chunk_str = tok.decode(tokens[i:i + max_tokens], skip_special_tokens=True)
                if chunk_str.strip():  # skip empty decodes
                    chunks.append(chunk_str)
                i += max_tokens - stride
            return chunks
        except Exception as e:
            # Fall back to word chunking, but discard any partial token
            # chunks first — the old code kept them, duplicating content.
            print(f"⚠️ Token chunking failed, falling back to words: {e}")
            chunks = []
    # Word-based fallback: ~2700 words with a 400-word overlap roughly
    # matches the 3500-token / 500-token-stride configuration above.
    words_per_chunk = 2700
    word_stride = 400
    words = text.split()
    i = 0
    while i < len(words) and len(chunks) < max_chunks:
        chunk_str = " ".join(words[i:i + words_per_chunk])
        if len(chunk_str) > 300:  # drop trailing fragments too small to train on
            chunks.append(chunk_str)
        i += words_per_chunk - word_stride
    return chunks
def clean_text_content(text):
    """Strip Project Gutenberg-style boilerplate surrounding the body text.

    Keeps only the text between the first "*** START OF" marker and the
    following "*** END OF" marker; returns the stripped input unchanged when
    either marker is missing.
    """
    markers = [("*** START OF", "*** END OF")]
    for start_m, end_m in markers:
        s = text.find(start_m)
        # BUGFIX: search for the end marker *after* the start marker. The old
        # code used text.find(end_m) from position 0, so a stray "*** END OF"
        # in the header produced an empty (or reversed) slice and wiped the book.
        e = text.find(end_m, s + len(start_m)) if s != -1 else -1
        if s != -1 and e != -1:
            text = text[s + len(start_m):e]
            break
    return text.strip()
def download_book(identifier, title, year, era_label, min_char_limit=5000):
    """Download a plain-text copy of an Internet Archive item and save it.

    Tries the OCR (_djvu.txt) rendition first, then the plain .txt rendition.

    Args:
        identifier: archive.org item identifier.
        title: human-readable title (sanitized into the filename).
        year: 4-digit year string ("0000" when unknown).
        era_label: class label the book is being collected for.
        min_char_limit: minimum cleaned-text length to accept the book.

    Returns:
        Metadata dict on success, or None when no usable text was retrieved.
    """
    urls = [
        f"https://archive.org/download/{identifier}/{identifier}_djvu.txt",
        f"https://archive.org/download/{identifier}/{identifier}.txt"
    ]
    content = ""
    for url in urls:
        try:
            r = requests.get(url, timeout=15)
            if r.status_code == 200:
                content = r.text
                break
        except requests.RequestException as e:
            # Narrowed from a bare except: only network failures are expected
            # here and safe to skip; programming errors now surface normally.
            print(f" ⚠️ Download failed for {url}: {e}")
    content = clean_text_content(content)
    if len(content) < min_char_limit:
        return None
    # Filename encodes year + era + sanitized title + identifier for traceability.
    safe_title = re.sub(r'[^a-zA-Z0-9]', '_', title)[:40]
    filename = f"{year}_{era_label}_{safe_title}_{identifier}.txt"
    with open(os.path.join(BOOKS_DIR, filename), "w", encoding="utf-8") as f:
        f.write(content)
    return {
        "title": title, "year": int(year), "era_label": era_label,
        "filename": filename, "char_count": len(content), "source": "Internet Archive"
    }
def generate_dataset(total_books_needed, progress=gr.Progress()):
    """Scrape Internet Archive for era-labelled books, chunk them, and zip a dataset.

    For each of the ten ERAS this searches archive.org (with rescue-keyword
    queries for sparse eras and an emergency fallback for Late Medieval),
    downloads plain-text books, splits them into Longformer-sized chunks, and
    writes metadata.csv plus a JSONL dataset before archiving the directory.

    Args:
        total_books_needed: overall book target, split evenly across eras.
        progress: Gradio progress callback (injected by the UI).

    Returns:
        (zip_path, general_stats_df, split_stats_df), or
        (None, empty DataFrame, empty DataFrame) when nothing downloaded.
    """
    setup_dirs()
    records = []  # one metadata dict per successfully saved book
    books_per_era = max(1, int(total_books_needed / len(ERAS)))
    for start_year, end_year, era_label, search_hint in ERAS:
        collected = 0
        attempts = 0
        era_topics = TOPICS.copy()
        random.shuffle(era_topics)
        # Eras listed in RESCUE_KEYWORDS are "hard": smaller minimum book
        # size, more attempts, and keyword-based rescue queries.
        rescue_list = RESCUE_KEYWORDS.get(era_label, [])
        is_hard_era = len(rescue_list) > 0
        min_chars = 1000 if is_hard_era else 5000
        max_attempts = 80 if era_label == "1_Late_Medieval" else (50 if is_hard_era else 20)
        # Late Medieval goes straight to rescue queries; other hard eras try
        # three standard searches first.
        rescue_threshold = 0 if era_label == "1_Late_Medieval" else 3
        progress(0, desc=f"Scraping Era: {era_label}")
        print(f"\n{'='*60}")
        print(f"Starting Era: {era_label} (Target: {books_per_era} books)")
        print(f"{'='*60}")
        while collected < books_per_era and attempts < max_attempts:
            attempts += 1
            using_rescue = False
            if is_hard_era and attempts > rescue_threshold:
                # Rescue mode: query by well-known authors/works instead of
                # generic topics (date metadata is often missing for these).
                using_rescue = True
                kw = random.choice(rescue_list)
                if era_label == "1_Late_Medieval":
                    # Rotate through six query shapes to widen coverage.
                    query_type = attempts % 6
                    if query_type == 0:
                        query = f"title:({kw}) AND mediatype:texts"
                    elif query_type == 1:
                        query = f"({kw}) AND mediatype:texts AND language:eng"
                    elif query_type == 2:
                        query = f"subject:({kw}) AND mediatype:texts"
                    elif query_type == 3:
                        col = random.choice(LATE_MEDIEVAL_COLLECTIONS)
                        query = f"({kw}) AND collection:({col}) AND mediatype:texts"
                    elif query_type == 4:
                        query = f"({kw}) AND date:[1200 TO 1900] AND mediatype:texts AND language:eng"
                    else:
                        query = f"{kw} mediatype:texts"
                else:
                    if attempts % 3 == 0:
                        query = f"title:({kw}) AND mediatype:texts"
                    elif attempts % 3 == 1:
                        query = f"({kw}) AND mediatype:texts AND language:eng"
                    else:
                        query = f"subject:({kw}) AND mediatype:texts"
                print(f" > 🛡️ Rescue Search #{attempts} ({era_label}): {kw}")
            else:
                # Standard mode: cycle through shuffled generic topics,
                # constrained by the era's date window.
                if not era_topics:
                    era_topics = TOPICS.copy()
                    random.shuffle(era_topics)
                topic = era_topics.pop()
                query = f"(subject:{topic} OR {search_hint}) AND date:[{start_year} TO {end_year}] AND mediatype:texts AND language:eng"
                if end_year > 1928:
                    # Post-1928 works may still be in copyright; restrict to
                    # openly licensed items.
                    query += " AND (licenseurl:* OR rights:creative commons OR collection:opensourcemedia)"
                print(f" > Standard Search #{attempts}: {topic}")
            try:
                search_generator = internetarchive.search_items(
                    query,
                    sorts=['downloads desc'],
                    fields=['identifier', 'title', 'date', 'year']
                )
                # Pre-fetch a bounded batch so the search HTTP connection is
                # released quickly instead of staying open while downloading.
                search_results_batch = []
                max_check_per_query = (50 if era_label == "1_Late_Medieval" else (30 if is_hard_era else 10))
                for i, item in enumerate(search_generator):
                    search_results_batch.append(item)
                    if i >= max_check_per_query: break
                results_found = len(search_results_batch)
                for res in search_results_batch:
                    if collected >= books_per_era: break
                    id_ = res.get('identifier')
                    raw_date = res.get('date') or res.get('year')
                    year = str(raw_date)[:4] if raw_date else "0000"
                    if not year.isdigit(): year = "0000"
                    if not using_rescue:
                        # Standard hits must actually fall inside the era's
                        # date window; rescue hits are accepted on keyword alone.
                        if not (start_year <= int(year) <= end_year):
                            continue
                    # Skip identifiers already downloaded for any era.
                    if any(r['filename'].endswith(f"{id_}.txt") for r in records):
                        continue
                    rec = download_book(id_, res.get('title', 'Unknown'), year, era_label, min_char_limit=min_chars)
                    if rec:
                        rec['topic'] = "Classic" if using_rescue else topic
                        records.append(rec)
                        collected += 1
                        print(f" ✅ Saved ({collected}/{books_per_era}): {rec['title']} ({year})")
                if results_found == 0:
                    print(f" ⚠️ No results found for this query")
            except Exception as e:
                print(f" ❌ Search error: {e}")
            time.sleep(1)  # be polite to archive.org between queries
        print(f"Completed {era_label}: {collected}/{books_per_era} books collected")
        # Emergency fallback: Late Medieval routinely under-collects; retry
        # with very broad terms and no date filtering at all.
        if era_label == "1_Late_Medieval" and collected < books_per_era * 0.3:
            print(f"\n⚠️ EMERGENCY FALLBACK MODE for {era_label}")
            fallback_attempts = 0
            fallback_terms = [
                "medieval english", "middle english", "chaucer OR malory OR gower",
                "14th century OR 15th century", "medieval literature english",
                "arthurian romance", "medieval poetry english"
            ]
            while collected < books_per_era and fallback_attempts < len(fallback_terms):
                term = fallback_terms[fallback_attempts]
                fallback_attempts += 1
                query = f"({term}) AND mediatype:texts"
                print(f" > 🚨 Fallback #{fallback_attempts}: {term}")
                try:
                    search_generator = internetarchive.search_items(query, sorts=['downloads desc'], fields=['identifier', 'title', 'date', 'year'])
                    fallback_batch = []
                    for i, item in enumerate(search_generator):
                        fallback_batch.append(item)
                        if i >= 100: break
                    checked = 0
                    for res in fallback_batch:
                        if collected >= books_per_era:
                            break
                        checked += 1
                        id_ = res.get('identifier')
                        if any(r['filename'].endswith(f"{id_}.txt") for r in records):
                            continue
                        raw_date = res.get('date') or res.get('year')
                        year = str(raw_date)[:4] if raw_date else "0000"
                        if not year.isdigit(): year = "0000"
                        # NOTE: no date check here — fallback accepts any year
                        # to fill the under-represented class.
                        rec = download_book(id_, res.get('title', 'Unknown'), year, era_label, min_char_limit=min_chars)
                        if rec:
                            rec['topic'] = "Medieval"
                            records.append(rec)
                            collected += 1
                            print(f" ✅ FALLBACK Success ({collected}/{books_per_era}): {rec['title']}")
                except Exception as e:
                    print(f" ❌ Fallback error: {e}")
                time.sleep(1)
    if not records: return None, pd.DataFrame(), pd.DataFrame()
    print("\n" + "="*60)
    print("Starting Robust Chunking...")
    print("="*60)
    progress(0.9, desc="Chunking Text...")
    # Re-read each saved book from disk and split into overlapping chunks;
    # each chunk becomes one training example.
    longformer_rows = []
    for r in records:
        file_path = os.path.join(BOOKS_DIR, r["filename"])
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                raw_text = f.read()
            chunks = chunk_text_robust(raw_text)
            for idx, chunk in enumerate(chunks):
                longformer_rows.append({
                    "text": chunk,
                    "era_label": r["era_label"],
                    "year": r["year"],
                    "chunk_id": idx
                })
            print(f" ✅ Chunked {r['title']}: {len(chunks)} chunks")
        except Exception as e:
            print(f" ❌ Error processing {r['filename']}: {e}")
    # Per-class chunk counts with estimated 80/10/10 split sizes for the UI.
    df_rows = pd.DataFrame(longformer_rows)
    if not df_rows.empty:
        split_stats = df_rows['era_label'].value_counts().reset_index()
        split_stats.columns = ['Era Label', 'Total Chunks']
        split_stats['Est. Train (80%)'] = (split_stats['Total Chunks'] * 0.8).astype(int)
        split_stats['Est. Val (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Est. Test (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Status'] = split_stats['Est. Val (10%)'].apply(lambda x: "⚠️ LOW DATA" if x < 5 else "✅ OK")
    else:
        split_stats = pd.DataFrame()
    total_chunks = len(longformer_rows)
    avg_chunks = total_chunks / len(records) if records else 0
    general_stats_df = pd.DataFrame({
        "Metric": ["Total Books", "Total Training Examples", "Avg Examples/Book"],
        "Value": [len(records), total_chunks, f"{avg_chunks:.1f}"]
    })
    # Persist book metadata + the JSONL dataset, then zip the directory for download.
    pd.DataFrame(records).to_csv(os.path.join(DATASET_DIR, "metadata.csv"), index=False)
    jsonl_path = os.path.join(DATASET_DIR, "longformer_dataset.jsonl")
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for row in longformer_rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    timestamp = int(datetime.now().timestamp())
    zip_filename = f"Analyzed_ML_Dataset_{timestamp}"
    shutil.make_archive(zip_filename, 'zip', DATASET_DIR)
    print("\n" + "="*60)
    print("Dataset Generation Complete!")
    print("="*60)
    return f"{zip_filename}.zip", general_stats_df, split_stats
| # ============================================================================ | |
| # TAB 2: TRAINING | |
| # ============================================================================ | |
class LongformerDataset(Dataset):
    """Torch dataset pairing raw text chunks with integer era labels.

    Tokenization happens lazily in ``__getitem__``; every sample is padded or
    truncated to ``max_length`` so batches have a uniform shape.
    """

    def __init__(self, texts, labels, tokenizer, max_length=4096):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample_text = str(self.texts[idx])
        sample_label = self.labels[idx]
        encoded = self.tokenizer(
            sample_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Flatten the (1, max_length) tensors to 1-D so DataLoader batching
        # produces (batch, max_length) inputs.
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(sample_label, dtype=torch.long),
        }
def train_model(dataset_path, epochs, batch_size, learning_rate, gradient_accumulation_steps, progress=gr.Progress()):
    """Fine-tune Longformer on the JSONL chunk dataset with gradient accumulation.

    Args:
        dataset_path: path to longformer_dataset.jsonl (one JSON object per line).
        epochs: number of training epochs.
        batch_size: per-step batch size (keep at 1 on small GPUs).
        learning_rate: AdamW learning rate.
        gradient_accumulation_steps: micro-batches accumulated per optimizer step.
        progress: Gradio progress callback.

    Returns:
        (status_text, metrics_dataframe_or_None, saved_model_path_or_None).
    """
    global MODEL, TOKENIZER
    if not TOKENIZER:
        return "❌ Tokenizer not loaded. Please install transformers library.", None, None
    if not os.path.exists(dataset_path):
        return "❌ Dataset file not found. Please generate a dataset first.", None, None
    # Guard against slider values that would break the arithmetic below
    # (division by zero in the scheduler / accumulation logic).
    if batch_size < 1:
        return "❌ Error: Batch Size must be at least 1.", None, None
    if gradient_accumulation_steps < 1:
        return "❌ Error: Gradient Accumulation Steps must be at least 1.", None, None
    try:
        # Load dataset (JSONL: one {"text", "era_label", ...} object per line).
        progress(0.1, desc="Loading dataset...")
        data = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                data.append(json.loads(line))
        df = pd.DataFrame(data)
        texts = df['text'].tolist()
        labels = [LABEL_TO_ID[label] for label in df['era_label'].tolist()]
        # Stratified 80/10/10 train/val/test split (test set is held out unused here).
        progress(0.2, desc="Splitting data...")
        X_train, X_temp, y_train, y_temp = train_test_split(texts, labels, test_size=0.2, random_state=42, stratify=labels)
        X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)
        train_dataset = LongformerDataset(X_train, y_train, TOKENIZER)
        val_dataset = LongformerDataset(X_val, y_val, TOKENIZER)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        # Fresh classification head on top of the pretrained Longformer encoder.
        progress(0.3, desc="Initializing model...")
        MODEL = AutoModelForSequenceClassification.from_pretrained(
            "allenai/longformer-base-4096",
            num_labels=len(LABEL_TO_ID)
        )
        MODEL.to(DEVICE)
        optimizer = AdamW(MODEL.parameters(), lr=learning_rate)
        # The scheduler steps once per *optimizer* step, so total steps are
        # divided by the accumulation factor.
        total_steps = (len(train_loader) // gradient_accumulation_steps) * epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # Training loop
        train_losses = []
        val_accuracies = []
        step_count = 0  # counts micro-batches for gradient accumulation
        for epoch in range(epochs):
            MODEL.train()
            total_loss = 0
            for batch_idx, batch in enumerate(train_loader):
                # Progress is computed from batches seen, not optimizer steps.
                progress_val = (0.3 + (epoch / epochs) * 0.6) + ((batch_idx / len(train_loader)) / epochs * 0.6)
                progress(progress_val, desc=f"Training Epoch {epoch+1}/{epochs} (Batch {batch_idx+1}/{len(train_loader)})")
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)
                # Forward pass (the model computes cross-entropy when labels are given).
                outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                # Normalize so accumulated gradients match a full-batch update.
                loss = loss / gradient_accumulation_steps
                loss.backward()
                total_loss += loss.item() * gradient_accumulation_steps  # undo normalization for reporting
                step_count += 1
                # Apply the optimizer every N micro-batches, and at epoch end
                # so trailing gradients are never dropped.
                if step_count % gradient_accumulation_steps == 0 or batch_idx == len(train_loader) - 1:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
            # Validation pass (once per epoch, no gradients).
            MODEL.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in val_loader:
                    input_ids = batch['input_ids'].to(DEVICE)
                    attention_mask = batch['attention_mask'].to(DEVICE)
                    labels = batch['labels'].to(DEVICE)
                    outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
                    predictions = torch.argmax(outputs.logits, dim=1)
                    correct += (predictions == labels).sum().item()
                    total += labels.size(0)
            avg_loss = total_loss / len(train_loader)
            val_acc = correct / total
            train_losses.append(avg_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}")
        # Persist model + tokenizer under a timestamped directory.
        progress(0.95, desc="Saving model...")
        timestamp = int(datetime.now().timestamp())
        model_path = os.path.join(MODEL_DIR, f"longformer_era_classifier_{timestamp}")
        MODEL.save_pretrained(model_path)
        TOKENIZER.save_pretrained(model_path)
        # Per-epoch metrics for the UI table.
        metrics_df = pd.DataFrame({
            "Epoch": list(range(1, epochs + 1)),
            "Training Loss": train_losses,
            "Validation Accuracy": [f"{acc:.4f}" for acc in val_accuracies]
        })
        summary = f"✅ Training Complete!\nFinal Val Acc: {val_accuracies[-1]:.4f}\nModel saved to: {model_path}"
        return summary, metrics_df, model_path
    except RuntimeError as e:
        # OOM is the most common failure mode on shared GPUs; give actionable advice.
        if 'out of memory' in str(e):
            return f"❌ Training error: CUDA Out Of Memory. Try reducing the 'Batch Size' slider to 1, or increase 'Gradient Accumulation Steps'. Error: {str(e)}", None, None
        return f"❌ Training error: {str(e)}", None, None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None, None
| # ============================================================================ | |
| # TAB 3: TESTING | |
| # ============================================================================ | |
def load_trained_model(model_path):
    """Load a fine-tuned classifier and tokenizer from *model_path* into the
    module globals used for prediction, moving the model to the active device
    in eval mode. Returns a human-readable status string."""
    global MODEL, TOKENIZER
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(model_path)
        MODEL = AutoModelForSequenceClassification.from_pretrained(model_path)
        MODEL.to(DEVICE)
        MODEL.eval()
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
    return f"✅ Model loaded successfully from {model_path}"
def predict_era(text, model_path):
    """Classify *text* into one of the ten historical eras.

    Lazily loads a saved model from *model_path* when none is resident.
    Returns (markdown result string, top-3 DataFrame) or (error string, None).
    """
    global MODEL, TOKENIZER
    # Guard clauses: make sure a model is in memory before predicting.
    if not (MODEL and TOKENIZER):
        if not (model_path and os.path.exists(model_path)):
            return "❌ No model loaded. Please train a model first or provide a valid model path.", None
        load_result = load_trained_model(model_path)
        if "Error" in load_result:
            return load_result, None
    try:
        encoded = TOKENIZER(
            text,
            add_special_tokens=True,
            max_length=4096,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        ids = encoded['input_ids'].to(DEVICE)
        mask = encoded['attention_mask'].to(DEVICE)
        with torch.no_grad():
            logits = MODEL(input_ids=ids, attention_mask=mask).logits
        probabilities = torch.softmax(logits, dim=1)[0]
        predicted_class = torch.argmax(probabilities).item()
        # Top-3 classes with percentage confidences for the UI table.
        top_probs, top_indices = torch.topk(probabilities, 3)
        results = [
            {"Era": ID_TO_LABEL[i.item()], "Confidence": f"{p.item() * 100:.2f}%"}
            for i, p in zip(top_indices, top_probs)
        ]
        predicted_era = ID_TO_LABEL[predicted_class]
        result_text = f"🎯 **Predicted Era:** {predicted_era}\n\n**Confidence:** {probabilities[predicted_class].item()*100:.2f}%"
        return result_text, pd.DataFrame(results)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}", None
# ============================================================================
# GRADIO UI
# ============================================================================
# Three tabs mirroring the pipeline stages: scrape dataset -> train -> test.
with gr.Blocks(title="Complete ML Pipeline") as demo:
    gr.Markdown("# 📚 Complete ML Pipeline: Dataset Generation, Training & Testing")
    with gr.Tabs():
        # TAB 1: Dataset Generation
        with gr.Tab("📊 Dataset Generation"):
            gr.Markdown("## Generate Historical Text Dataset")
            gr.Markdown("Dataset generation logic is now stabilized for network timeouts.")
            with gr.Row():
                dataset_slider = gr.Slider(10, 100, step=10, value=50, label="Total Books to Collect")
                generate_btn = gr.Button("🚀 Generate Dataset", variant="primary", size="lg")
            dataset_download = gr.File(label="📥 Download Dataset ZIP")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### General Summary")
                    gen_stats = gr.Dataframe()
                with gr.Column():
                    gr.Markdown("### Class Balance Check")
                    split_stats = gr.Dataframe()
            generate_btn.click(
                generate_dataset,
                inputs=[dataset_slider],
                outputs=[dataset_download, gen_stats, split_stats]
            )
        # TAB 2: Training
        with gr.Tab("🎓 Model Training"):
            gr.Markdown("## Train Longformer Era Classifier")
            gr.Markdown("""
**TRAINING FIX:** Training may stall due to CUDA Out Of Memory errors.
If training stalls, try setting **Batch Size to 1** and **Gradient Accumulation Steps to 4 or higher**.
""")
            with gr.Row():
                with gr.Column():
                    train_dataset_path = gr.Textbox(
                        label="Dataset Path",
                        value=os.path.join(DATASET_DIR, "longformer_dataset.jsonl"),
                        placeholder="Path to dataset JSONL file"
                    )
                    train_epochs = gr.Slider(1, 10, step=1, value=3, label="Epochs")
                    # Batch size defaults to 1; effective batch = batch * accumulation steps.
                    train_batch = gr.Slider(1, 8, step=1, value=1, label="Batch Size (Memory Control)")
                    train_accum = gr.Slider(1, 16, step=1, value=4, label="Gradient Accumulation Steps (Effective Batch Size)")
                    train_lr = gr.Number(value=2e-5, label="Learning Rate")
                    train_btn = gr.Button("🏋️ Start Training", variant="primary", size="lg")
                with gr.Column():
                    train_output = gr.Textbox(label="Training Status", lines=8)
                    train_metrics = gr.Dataframe(label="Training Metrics")
                    model_path_output = gr.Textbox(label="Saved Model Path")
            # Input order matches train_model(dataset_path, epochs, batch_size,
            # learning_rate, gradient_accumulation_steps).
            train_btn.click(
                train_model,
                inputs=[train_dataset_path, train_epochs, train_batch, train_lr, train_accum],
                outputs=[train_output, train_metrics, model_path_output]
            )
        # TAB 3: Testing
        with gr.Tab("🧪 Model Testing"):
            gr.Markdown("## Test Era Classification")
            with gr.Row():
                with gr.Column():
                    test_model_path = gr.Textbox(
                        label="Model Path (optional - uses last trained model if empty)",
                        placeholder="trained_models/longformer_era_classifier_..."
                    )
                    test_input = gr.Textbox(
                        label="Input Text",
                        lines=10,
                        placeholder="Paste historical text here...\n\nExample: 'When that Aprille with his shoures soote, The droghte of Marche hath perced to the roote...'"
                    )
                    test_btn = gr.Button("🔍 Predict Era", variant="primary", size="lg")
                with gr.Column():
                    test_result = gr.Markdown(label="Prediction Result")
                    test_probabilities = gr.Dataframe(label="Top 3 Predictions")
            # One-click sample texts for quick manual testing of each era.
            gr.Markdown("### 📝 Try Sample Texts")
            with gr.Row():
                sample1 = gr.Button("Medieval Sample")
                sample2 = gr.Button("Victorian Sample")
                sample3 = gr.Button("Contemporary Sample")
            def load_medieval():
                # Opening of Beowulf (Old English).
                return "Hwæt! We Gardena in geardagum, þeodcyninga, þrym gefrunon, hu ða æþelingas ellen fremedon."
            def load_victorian():
                return "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife."
            def load_contemporary():
                return "The internet has fundamentally transformed how we communicate, work, and access information in the digital age."
            sample1.click(load_medieval, outputs=[test_input])
            sample2.click(load_victorian, outputs=[test_input])
            sample3.click(load_contemporary, outputs=[test_input])
            test_btn.click(
                predict_era,
                inputs=[test_input, test_model_path],
                outputs=[test_result, test_probabilities]
            )
    gr.Markdown("---")
    gr.Markdown(f"**Device:** {DEVICE} | **Status:** {'✅ Ready' if TOKENIZER else '⚠️ Transformers not installed'}")
if __name__ == "__main__":
    # ssr_mode=False avoids server-side-rendering issues on Spaces.
    demo.launch(ssr_mode=False)