# book-scraper / app.py — Hugging Face Space (commit 59ec29c, verified)
# Uploaded by user "wuhp"; the Space page header was converted to this comment
# so the file parses as Python.
import gradio as gr
import pandas as pd
import requests
import internetarchive
from datetime import datetime
import re
import os
import shutil
import time
import random
import json
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import nest_asyncio
import sys
# --- SYSTEM FIXES ---
# Patch the already-running event loop so blocking calls work inside
# environments that host their own asyncio loop (e.g. Gradio / notebooks).
try:
    nest_asyncio.apply()
except Exception as e:
    # Non-fatal: the app can still run without the patch.
    print(f"Warning: Could not apply nest_asyncio: {e}")
# --- CONFIGURATION ---
# All scraped books and derived training files live under DATASET_DIR;
# fine-tuned checkpoints are written to MODEL_DIR.
DATASET_DIR = "dataset_ml_final_v2"
BOOKS_DIR = os.path.join(DATASET_DIR, "books")
MODEL_DIR = "trained_models"
os.makedirs(MODEL_DIR, exist_ok=True)
# --- TOKENIZER & MODEL ---
# Module-level globals: TOKENIZER is loaded eagerly below; MODEL is populated
# later by train_model() or load_trained_model().
TOKENIZER = None
MODEL = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, logging
    from torch.optim import AdamW
    logging.set_verbosity_error()
    print("Attempting to load Longformer Tokenizer...")
    TOKENIZER = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    print("✅ Tokenizer loaded successfully.")
except Exception as e:
    print(f"⚠️ Tokenizer loading error: {e}")
    # If the transformers import failed, AdamW was never bound; pin it to None
    # so later code can detect the missing dependency.
    AdamW = None
# --- ERAS (10 Distinct Periods) ---
# Each tuple: (start_year, end_year, classification label, Internet Archive
# search hint appended to standard-mode queries).
ERAS = [
    (500, 1200, "0_Medieval", "Medieval OR Middle Ages OR Latin manuscripts"),
    (1200, 1470, "1_Late_Medieval", "Middle English OR Old French OR Troubadour"),
    (1470, 1650, "2_Early_Modern_Renaissance", "Renaissance OR Early Modern"),
    (1650, 1800, "3_Enlightenment_Classical", "Enlightenment OR Classical literature"),
    (1800, 1850, "4_Romantic", "Romanticism OR Romantic period"),
    (1850, 1920, "5_Industrial_Victorian", "Victorian OR Industrial Age"),
    (1920, 1945, "6_Modernist", "Modernism OR Avant-garde"),
    (1945, 1960, "7_Postwar_Early_Modern", "Postwar OR Early Cold War"),
    (1960, 1990, "8_Late_20th_Century", "Late 20th Century OR Postmodern"),
    (1990, 2024, "9_Contemporary_Information_Age", "Contemporary OR Digital era")
]
# Classification targets and the label <-> class-index maps used by the model.
ERA_LABELS = [era[2] for era in ERAS]
LABEL_TO_ID = {label: idx for idx, label in enumerate(ERA_LABELS)}
ID_TO_LABEL = {idx: label for idx, label in enumerate(ERA_LABELS)}
# --- RESCUE KEYWORDS ---
# Curated search terms for eras where generic topic/date queries return too
# few usable texts. Having a key here marks the era as "hard" in
# generate_dataset (lower char floor, more attempts, keyword rescue queries).
RESCUE_KEYWORDS = {
    "0_Medieval": [
        "Beowulf", "Bede", "Anglo Saxon Chronicle", "Cynewulf", "Caedmon",
        "Old English Homilies", "Aelfric", "Boethius", "Alfred the Great",
        "Venerable Bede", "Old English", "Anglo-Saxon poetry"
    ],
    "1_Late_Medieval": [
        "Chaucer", "Canterbury Tales", "Piers Plowman", "Langland",
        "Gower", "Malory", "Morte d'Arthur", "Wycliffe",
        "Julian Norwich", "Margery Kempe", "Froissart", "Everyman",
        "Gawain", "Pearl Poet", "Lydgate", "Troilus Criseyde",
        "Book Duchess", "Parliament Fowls", "Legend Good Women",
        "Christine Pizan", "Romance Rose", "Confessio Amantis",
        "mystery plays", "miracle plays", "morality plays",
        "Middle English", "medieval romance", "medieval literature",
        "14th century literature", "15th century literature",
        "medieval poetry", "medieval drama", "Arthurian legend",
        "Chivalric romance", "Courtly love", "medieval manuscript",
        "Caxton", "medieval texts", "English medieval", "French medieval"
    ]
}
# Internet Archive collections rotated through during Late Medieval rescue
# queries.
LATE_MEDIEVAL_COLLECTIONS = [
    "gutenberg", "opensource", "medievaltexts", "earlyenglishbooksonline",
    "englishliterature", "medievalmanuscripts", "britishlibrary"
]
# Subject topics cycled through by standard-mode searches.
TOPICS = [
    "History", "Philosophy", "Science", "Mathematics", "Medicine", "Astronomy",
    "Physics", "Chemistry", "Biology", "Fiction", "Poetry", "Drama", "Mythology",
    "Folklore", "Religion", "Theology", "Biography", "Politics", "Economics", "Law",
    "Sociology", "Technology", "Engineering", "Travel", "War", "Military", "Art",
    "Psychology", "Anthropology", "Literature", "Essays", "Memoirs", "Education"
]
# ============================================================================
# TAB 1: DATASET GENERATION
# ============================================================================
def setup_dirs():
    """Reset the dataset directory: wipe any previous run, recreate the books folder."""
    if os.path.exists(DATASET_DIR):
        # Best-effort removal; ignore_errors preserves the original intent of
        # the silent bare-except without swallowing unrelated exceptions.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
    os.makedirs(BOOKS_DIR, exist_ok=True)
def chunk_text_robust(text, tokenizer=None):
    """Split *text* into overlapping chunks sized for Longformer training.

    Prefers token-accurate chunking via the tokenizer; falls back to a
    word-count approximation when no tokenizer is available or tokenization
    fails.

    Args:
        text: Raw book text.
        tokenizer: Optional tokenizer override; defaults to the module-level
            TOKENIZER (backward compatible with existing callers).

    Returns:
        List of chunk strings, capped at MAX_CHUNKS_PER_BOOK entries.
    """
    MAX_TOKENS = 3500          # chunk length in tokens (< Longformer's 4096 window)
    STRIDE = 500               # overlap between consecutive token chunks
    MAX_CHUNKS_PER_BOOK = 40   # cap so one huge book cannot dominate the dataset
    tok = tokenizer if tokenizer is not None else TOKENIZER
    chunks = []
    if tok:
        try:
            tokens = tok.encode(text, add_special_tokens=False)
            i = 0
            while i < len(tokens) and len(chunks) < MAX_CHUNKS_PER_BOOK:
                chunk_ids = tokens[i : i + MAX_TOKENS]
                chunks.append(tok.decode(chunk_ids, skip_special_tokens=True))
                i += (MAX_TOKENS - STRIDE)
            return chunks
        except Exception:
            # Tokenization failed (e.g. pathological input); fall through to
            # the word-based approximation below.
            pass
    # Word-based fallback: ~2700 words approximates 3500 subword tokens.
    WORDS_PER_CHUNK = 2700
    WORD_STRIDE = 400
    words = text.split()
    i = 0
    while i < len(words) and len(chunks) < MAX_CHUNKS_PER_BOOK:
        chunk_str = " ".join(words[i : i + WORDS_PER_CHUNK])
        if len(chunk_str) > 300:  # drop tiny trailing slivers
            chunks.append(chunk_str)
        i += (WORDS_PER_CHUNK - WORD_STRIDE)
    return chunks
def clean_text_content(text):
    """Trim Project Gutenberg style boilerplate, keeping only the body text.

    If both a "*** START OF" and a "*** END OF" marker are present, the text
    between them is kept; otherwise the input is returned unchanged apart
    from surrounding whitespace.
    """
    boilerplate_markers = (("*** START OF", "*** END OF"),)
    for begin_tag, end_tag in boilerplate_markers:
        begin_at = text.find(begin_tag)
        end_at = text.find(end_tag)
        if begin_at == -1 or end_at == -1:
            continue
        text = text[begin_at + len(begin_tag):end_at]
        break
    return text.strip()
def download_book(identifier, title, year, era_label, min_char_limit=5000):
    """Download a plain-text rendition of an Internet Archive item and save it.

    Tries the OCR (_djvu.txt) file first, then the plain .txt file.

    Args:
        identifier: Internet Archive item identifier.
        title: Human-readable title (sanitized into the filename).
        year: Four-character year string used in the filename and record.
        era_label: Era bucket the book belongs to.
        min_char_limit: Minimum cleaned length; shorter texts are rejected.

    Returns:
        Metadata dict for the saved book, or None if no usable text was found.
    """
    urls = [
        f"https://archive.org/download/{identifier}/{identifier}_djvu.txt",
        f"https://archive.org/download/{identifier}/{identifier}.txt"
    ]
    content = ""
    for url in urls:
        try:
            r = requests.get(url, timeout=15)
            if r.status_code == 200:
                content = r.text
                break
        except requests.RequestException:
            # Network/timeout failure for this rendition — try the next URL.
            continue
    content = clean_text_content(content)
    if len(content) < min_char_limit:
        return None
    # Sanitize the title so the filename is filesystem-safe.
    safe_title = re.sub(r'[^a-zA-Z0-9]', '_', title)[:40]
    filename = f"{year}_{era_label}_{safe_title}_{identifier}.txt"
    with open(os.path.join(BOOKS_DIR, filename), "w", encoding="utf-8") as f:
        f.write(content)
    return {
        "title": title, "year": int(year), "era_label": era_label,
        "filename": filename, "char_count": len(content), "source": "Internet Archive"
    }
def generate_dataset(total_books_needed, progress=gr.Progress()):
    """Scrape Internet Archive texts per era, chunk them, and package a dataset.

    Args:
        total_books_needed: Overall book target, split evenly across the eras.
        progress: Gradio progress callback for UI feedback.

    Returns:
        Tuple of (zip file path, general stats DataFrame, per-era split stats
        DataFrame), or (None, empty DataFrame, empty DataFrame) if nothing
        was collected.
    """
    setup_dirs()
    records = []
    books_per_era = max(1, int(total_books_needed / len(ERAS)))
    for start_year, end_year, era_label, search_hint in ERAS:
        collected = 0
        attempts = 0
        era_topics = TOPICS.copy()
        random.shuffle(era_topics)
        # Eras listed in RESCUE_KEYWORDS are "hard" (sparse results): they get
        # a lower character floor, more attempts, and keyword rescue queries.
        rescue_list = RESCUE_KEYWORDS.get(era_label, [])
        is_hard_era = len(rescue_list) > 0
        min_chars = 1000 if is_hard_era else 5000
        max_attempts = 80 if era_label == "1_Late_Medieval" else (50 if is_hard_era else 20)
        # Late Medieval goes straight to rescue mode (threshold 0); other hard
        # eras try a few standard searches first.
        rescue_threshold = 0 if era_label == "1_Late_Medieval" else 3
        progress(0, desc=f"Scraping Era: {era_label}")
        print(f"\n{'='*60}")
        print(f"Starting Era: {era_label} (Target: {books_per_era} books)")
        print(f"{'='*60}")
        while collected < books_per_era and attempts < max_attempts:
            attempts += 1
            using_rescue = False
            if is_hard_era and attempts > rescue_threshold:
                # Rescue mode: search by curated keyword instead of topic/date.
                using_rescue = True
                kw = random.choice(rescue_list)
                if era_label == "1_Late_Medieval":
                    # Rotate through six query shapes to broaden coverage.
                    query_type = attempts % 6
                    if query_type == 0:
                        query = f"title:({kw}) AND mediatype:texts"
                    elif query_type == 1:
                        query = f"({kw}) AND mediatype:texts AND language:eng"
                    elif query_type == 2:
                        query = f"subject:({kw}) AND mediatype:texts"
                    elif query_type == 3:
                        col = random.choice(LATE_MEDIEVAL_COLLECTIONS)
                        query = f"({kw}) AND collection:({col}) AND mediatype:texts"
                    elif query_type == 4:
                        query = f"({kw}) AND date:[1200 TO 1900] AND mediatype:texts AND language:eng"
                    else:
                        query = f"{kw} mediatype:texts"
                else:
                    if attempts % 3 == 0:
                        query = f"title:({kw}) AND mediatype:texts"
                    elif attempts % 3 == 1:
                        query = f"({kw}) AND mediatype:texts AND language:eng"
                    else:
                        query = f"subject:({kw}) AND mediatype:texts"
                print(f" > 🛡️ Rescue Search #{attempts} ({era_label}): {kw}")
            else:
                # Standard mode: topic + era hint constrained to the era's
                # publication-date range. Topics are popped without repeats
                # until the pool empties, then reshuffled.
                if not era_topics:
                    era_topics = TOPICS.copy()
                    random.shuffle(era_topics)
                topic = era_topics.pop()
                query = f"(subject:{topic} OR {search_hint}) AND date:[{start_year} TO {end_year}] AND mediatype:texts AND language:eng"
                if end_year > 1928:
                    # Recent works may be in copyright; restrict to openly
                    # licensed material.
                    query += " AND (licenseurl:* OR rights:creative commons OR collection:opensourcemedia)"
                print(f" > Standard Search #{attempts}: {topic}")
            try:
                search_generator = internetarchive.search_items(
                    query,
                    sorts=['downloads desc'],
                    fields=['identifier', 'title', 'date', 'year']
                )
                # Pre-fetch a bounded batch of results so the search connection
                # closes quickly instead of streaming during slow downloads.
                search_results_batch = []
                max_check_per_query = (50 if era_label == "1_Late_Medieval" else (30 if is_hard_era else 10))
                for i, item in enumerate(search_generator):
                    search_results_batch.append(item)
                    if i >= max_check_per_query: break
                results_found = len(search_results_batch)
                for res in search_results_batch:
                    if collected >= books_per_era: break
                    id_ = res.get('identifier')
                    raw_date = res.get('date') or res.get('year')
                    year = str(raw_date)[:4] if raw_date else "0000"
                    if not year.isdigit(): year = "0000"
                    if not using_rescue:
                        # Standard results must actually fall inside the era.
                        if not (start_year <= int(year) <= end_year):
                            continue
                    # Skip identifiers already downloaded in earlier attempts.
                    if any(r['filename'].endswith(f"{id_}.txt") for r in records):
                        continue
                    rec = download_book(id_, res.get('title', 'Unknown'), year, era_label, min_char_limit=min_chars)
                    if rec:
                        rec['topic'] = "Classic" if using_rescue else topic
                        records.append(rec)
                        collected += 1
                        print(f" ✅ Saved ({collected}/{books_per_era}): {rec['title']} ({year})")
                if results_found == 0:
                    print(f" ⚠️ No results found for this query")
            except Exception as e:
                print(f" ❌ Search error: {e}")
            # Be polite to the Archive API between queries.
            time.sleep(1)
        print(f"Completed {era_label}: {collected}/{books_per_era} books collected")
        # Emergency fallback: if Late Medieval came up nearly empty, retry
        # with broad hand-picked terms and no date filtering.
        if era_label == "1_Late_Medieval" and collected < books_per_era * 0.3:
            print(f"\n⚠️ EMERGENCY FALLBACK MODE for {era_label}")
            fallback_attempts = 0
            fallback_terms = [
                "medieval english", "middle english", "chaucer OR malory OR gower",
                "14th century OR 15th century", "medieval literature english",
                "arthurian romance", "medieval poetry english"
            ]
            while collected < books_per_era and fallback_attempts < len(fallback_terms):
                term = fallback_terms[fallback_attempts]
                fallback_attempts += 1
                query = f"({term}) AND mediatype:texts"
                print(f" > 🚨 Fallback #{fallback_attempts}: {term}")
                try:
                    search_generator = internetarchive.search_items(query, sorts=['downloads desc'], fields=['identifier', 'title', 'date', 'year'])
                    fallback_batch = []
                    for i, item in enumerate(search_generator):
                        fallback_batch.append(item)
                        if i >= 100: break
                    checked = 0
                    for res in fallback_batch:
                        if collected >= books_per_era:
                            break
                        checked += 1
                        id_ = res.get('identifier')
                        if any(r['filename'].endswith(f"{id_}.txt") for r in records):
                            continue
                        raw_date = res.get('date') or res.get('year')
                        year = str(raw_date)[:4] if raw_date else "0000"
                        if not year.isdigit(): year = "0000"
                        rec = download_book(id_, res.get('title', 'Unknown'), year, era_label, min_char_limit=min_chars)
                        if rec:
                            rec['topic'] = "Medieval"
                            records.append(rec)
                            collected += 1
                            print(f" ✅ FALLBACK Success ({collected}/{books_per_era}): {rec['title']}")
                except Exception as e:
                    print(f" ❌ Fallback error: {e}")
                time.sleep(1)
    if not records: return None, pd.DataFrame(), pd.DataFrame()
    print("\n" + "="*60)
    print("Starting Robust Chunking...")
    print("="*60)
    progress(0.9, desc="Chunking Text...")
    # Turn each downloaded book into overlapping training chunks.
    longformer_rows = []
    for r in records:
        file_path = os.path.join(BOOKS_DIR, r["filename"])
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                raw_text = f.read()
            chunks = chunk_text_robust(raw_text)
            for idx, chunk in enumerate(chunks):
                longformer_rows.append({
                    "text": chunk,
                    "era_label": r["era_label"],
                    "year": r["year"],
                    "chunk_id": idx
                })
            print(f" ✅ Chunked {r['title']}: {len(chunks)} chunks")
        except Exception as e:
            print(f" ❌ Error processing {r['filename']}: {e}")
    df_rows = pd.DataFrame(longformer_rows)
    if not df_rows.empty:
        # Per-era chunk counts plus estimated 80/10/10 split sizes.
        split_stats = df_rows['era_label'].value_counts().reset_index()
        split_stats.columns = ['Era Label', 'Total Chunks']
        split_stats['Est. Train (80%)'] = (split_stats['Total Chunks'] * 0.8).astype(int)
        split_stats['Est. Val (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Est. Test (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
        split_stats['Status'] = split_stats['Est. Val (10%)'].apply(lambda x: "⚠️ LOW DATA" if x < 5 else "✅ OK")
    else:
        split_stats = pd.DataFrame()
    total_chunks = len(longformer_rows)
    avg_chunks = total_chunks / len(records) if records else 0
    general_stats_df = pd.DataFrame({
        "Metric": ["Total Books", "Total Training Examples", "Avg Examples/Book"],
        "Value": [len(records), total_chunks, f"{avg_chunks:.1f}"]
    })
    # Persist metadata CSV and the JSONL training file, then zip everything.
    pd.DataFrame(records).to_csv(os.path.join(DATASET_DIR, "metadata.csv"), index=False)
    jsonl_path = os.path.join(DATASET_DIR, "longformer_dataset.jsonl")
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for row in longformer_rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    timestamp = int(datetime.now().timestamp())
    zip_filename = f"Analyzed_ML_Dataset_{timestamp}"
    shutil.make_archive(zip_filename, 'zip', DATASET_DIR)
    print("\n" + "="*60)
    print("Dataset Generation Complete!")
    print("="*60)
    return f"{zip_filename}.zip", general_stats_df, split_stats
# ============================================================================
# TAB 2: TRAINING
# ============================================================================
class LongformerDataset(Dataset):
    """Torch dataset yielding tokenized (input_ids, attention_mask, labels) samples.

    Each text is tokenized on access, padded/truncated to ``max_length``, and
    returned as flat 1-D tensors ready for a DataLoader.
    """

    def __init__(self, texts, labels, tokenizer, max_length=4096):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Drop the batch dimension the tokenizer adds, and attach the label.
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample
def train_model(dataset_path, epochs, batch_size, learning_rate, gradient_accumulation_steps, progress=gr.Progress()):
    """Fine-tune a Longformer sequence classifier on the JSONL chunk dataset.

    Args:
        dataset_path: Path to the longformer_dataset.jsonl produced by
            generate_dataset.
        epochs: Number of training epochs.
        batch_size: Per-step DataLoader batch size.
        learning_rate: AdamW learning rate.
        gradient_accumulation_steps: Batches to accumulate before each
            optimizer step (effective batch = batch_size * this).
        progress: Gradio progress callback.

    Returns:
        Tuple of (status string, metrics DataFrame or None, saved model path
        or None). Errors are reported via the status string, not raised.
    """
    global MODEL, TOKENIZER
    if not TOKENIZER:
        return "❌ Tokenizer not loaded. Please install transformers library.", None, None
    if not os.path.exists(dataset_path):
        return "❌ Dataset file not found. Please generate a dataset first.", None, None
    # Guard against invalid slider values.
    if batch_size < 1:
        return "❌ Error: Batch Size must be at least 1.", None, None
    if gradient_accumulation_steps < 1:
        return "❌ Error: Gradient Accumulation Steps must be at least 1.", None, None
    try:
        # Load dataset (one JSON object per line).
        progress(0.1, desc="Loading dataset...")
        data = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                data.append(json.loads(line))
        df = pd.DataFrame(data)
        texts = df['text'].tolist()
        labels = [LABEL_TO_ID[label] for label in df['era_label'].tolist()]
        # Stratified 80/10/10 split: the second call halves the 20% remainder.
        progress(0.2, desc="Splitting data...")
        X_train, X_temp, y_train, y_temp = train_test_split(texts, labels, test_size=0.2, random_state=42, stratify=labels)
        X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)
        train_dataset = LongformerDataset(X_train, y_train, TOKENIZER)
        val_dataset = LongformerDataset(X_val, y_val, TOKENIZER)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        # Initialize model
        progress(0.3, desc="Initializing model...")
        MODEL = AutoModelForSequenceClassification.from_pretrained(
            "allenai/longformer-base-4096",
            num_labels=len(LABEL_TO_ID)
        )
        MODEL.to(DEVICE)
        optimizer = AdamW(MODEL.parameters(), lr=learning_rate)
        # The scheduler advances once per optimizer step, so divide the batch
        # count by the accumulation factor.
        total_steps = (len(train_loader) // gradient_accumulation_steps) * epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # Training loop
        train_losses = []
        val_accuracies = []
        step_count = 0  # Counts batches for gradient accumulation
        for epoch in range(epochs):
            MODEL.train()
            total_loss = 0
            for batch_idx, batch in enumerate(train_loader):
                # Progress is based on total batches, not optimizer steps.
                progress_val = (0.3 + (epoch / epochs) * 0.6) + ((batch_idx / len(train_loader)) / epochs * 0.6)
                progress(progress_val, desc=f"Training Epoch {epoch+1}/{epochs} (Batch {batch_idx+1}/{len(train_loader)})")
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)
                # Forward pass (model computes the loss when labels are given)
                outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                # Normalize loss by accumulation steps so gradients average out
                loss = loss / gradient_accumulation_steps
                # Backward pass: gradients accumulate until the optimizer steps
                loss.backward()
                total_loss += loss.item() * gradient_accumulation_steps  # Scale back up for reporting
                step_count += 1
                # Step the optimizer every N batches, and at epoch end so no
                # accumulated gradient is left dangling.
                if step_count % gradient_accumulation_steps == 0 or batch_idx == len(train_loader) - 1:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
            # Validation (only runs once per epoch)
            MODEL.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in val_loader:
                    input_ids = batch['input_ids'].to(DEVICE)
                    attention_mask = batch['attention_mask'].to(DEVICE)
                    labels = batch['labels'].to(DEVICE)
                    outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
                    predictions = torch.argmax(outputs.logits, dim=1)
                    correct += (predictions == labels).sum().item()
                    total += labels.size(0)
            avg_loss = total_loss / len(train_loader)
            val_acc = correct / total
            train_losses.append(avg_loss)
            val_accuracies.append(val_acc)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}")
        # Save model + tokenizer side by side so the checkpoint is reloadable.
        progress(0.95, desc="Saving model...")
        timestamp = int(datetime.now().timestamp())
        model_path = os.path.join(MODEL_DIR, f"longformer_era_classifier_{timestamp}")
        MODEL.save_pretrained(model_path)
        TOKENIZER.save_pretrained(model_path)
        # Create metrics dataframe
        metrics_df = pd.DataFrame({
            "Epoch": list(range(1, epochs + 1)),
            "Training Loss": train_losses,
            "Validation Accuracy": [f"{acc:.4f}" for acc in val_accuracies]
        })
        summary = f"✅ Training Complete!\nFinal Val Acc: {val_accuracies[-1]:.4f}\nModel saved to: {model_path}"
        return summary, metrics_df, model_path
    except RuntimeError as e:
        # CUDA OOM gets a dedicated, actionable message.
        if 'out of memory' in str(e):
            return f"❌ Training error: CUDA Out Of Memory. Try reducing the 'Batch Size' slider to 1, or increase 'Gradient Accumulation Steps'. Error: {str(e)}", None, None
        return f"❌ Training error: {str(e)}", None, None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None, None
# ============================================================================
# TAB 3: TESTING
# ============================================================================
def load_trained_model(model_path):
    """Load a saved tokenizer + classifier checkpoint into the module globals.

    Moves the model to DEVICE and puts it in eval mode. Returns a status
    string instead of raising, so the UI can display the outcome directly.
    """
    global MODEL, TOKENIZER
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(model_path)
        MODEL = AutoModelForSequenceClassification.from_pretrained(model_path)
        MODEL.to(DEVICE)
        MODEL.eval()
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
    return f"✅ Model loaded successfully from {model_path}"
def predict_era(text, model_path):
    """Classify *text* into one of the ten eras using the loaded model.

    Lazily loads a checkpoint from *model_path* if no model is in memory.
    Returns (markdown result string, DataFrame of top-3 predictions), or an
    error string and None on failure.
    """
    global MODEL, TOKENIZER
    if not MODEL or not TOKENIZER:
        if not (model_path and os.path.exists(model_path)):
            return "❌ No model loaded. Please train a model first or provide a valid model path.", None
        load_result = load_trained_model(model_path)
        if "Error" in load_result:
            return load_result, None
    try:
        encoded = TOKENIZER(
            text,
            add_special_tokens=True,
            max_length=4096,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            outputs = MODEL(
                input_ids=encoded['input_ids'].to(DEVICE),
                attention_mask=encoded['attention_mask'].to(DEVICE)
            )
        # Softmax over the single-example logits, then rank the classes.
        probabilities = torch.softmax(outputs.logits, dim=1)[0]
        predicted_class = torch.argmax(probabilities).item()
        top_probs, top_indices = torch.topk(probabilities, 3)
        results = [
            {"Era": ID_TO_LABEL[i.item()], "Confidence": f"{p.item() * 100:.2f}%"}
            for i, p in zip(top_indices, top_probs)
        ]
        predicted_era = ID_TO_LABEL[predicted_class]
        result_text = f"🎯 **Predicted Era:** {predicted_era}\n\n**Confidence:** {probabilities[predicted_class].item()*100:.2f}%"
        return result_text, pd.DataFrame(results)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}", None
# ============================================================================
# GRADIO UI
# ============================================================================
# Three-tab UI: dataset generation, model training, and interactive testing.
with gr.Blocks(title="Complete ML Pipeline") as demo:
    gr.Markdown("# 📚 Complete ML Pipeline: Dataset Generation, Training & Testing")
    with gr.Tabs():
        # TAB 1: Dataset Generation
        with gr.Tab("📊 Dataset Generation"):
            gr.Markdown("## Generate Historical Text Dataset")
            gr.Markdown("Dataset generation logic is now stabilized for network timeouts.")
            with gr.Row():
                dataset_slider = gr.Slider(10, 100, step=10, value=50, label="Total Books to Collect")
            generate_btn = gr.Button("🚀 Generate Dataset", variant="primary", size="lg")
            dataset_download = gr.File(label="📥 Download Dataset ZIP")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### General Summary")
                    gen_stats = gr.Dataframe()
                with gr.Column():
                    gr.Markdown("### Class Balance Check")
                    split_stats = gr.Dataframe()
            generate_btn.click(
                generate_dataset,
                inputs=[dataset_slider],
                outputs=[dataset_download, gen_stats, split_stats]
            )
        # TAB 2: Training
        with gr.Tab("🎓 Model Training"):
            gr.Markdown("## Train Longformer Era Classifier")
            gr.Markdown("""
**TRAINING FIX:** Training may stall due to CUDA Out Of Memory errors.
If training stalls, try setting **Batch Size to 1** and **Gradient Accumulation Steps to 4 or higher**.
""")
            with gr.Row():
                with gr.Column():
                    train_dataset_path = gr.Textbox(
                        label="Dataset Path",
                        value=os.path.join(DATASET_DIR, "longformer_dataset.jsonl"),
                        placeholder="Path to dataset JSONL file"
                    )
                    train_epochs = gr.Slider(1, 10, step=1, value=3, label="Epochs")
                    # Default batch size of 1 for better memory headroom.
                    train_batch = gr.Slider(1, 8, step=1, value=1, label="Batch Size (Memory Control)")
                    # Effective batch size = batch size x accumulation steps.
                    train_accum = gr.Slider(1, 16, step=1, value=4, label="Gradient Accumulation Steps (Effective Batch Size)")
                    train_lr = gr.Number(value=2e-5, label="Learning Rate")
                    train_btn = gr.Button("🏋️ Start Training", variant="primary", size="lg")
                with gr.Column():
                    train_output = gr.Textbox(label="Training Status", lines=8)
                    train_metrics = gr.Dataframe(label="Training Metrics")
                    model_path_output = gr.Textbox(label="Saved Model Path")
            # Input order matches train_model's signature:
            # (dataset_path, epochs, batch_size, learning_rate, accumulation).
            train_btn.click(
                train_model,
                inputs=[train_dataset_path, train_epochs, train_batch, train_lr, train_accum],
                outputs=[train_output, train_metrics, model_path_output]
            )
        # TAB 3: Testing
        with gr.Tab("🧪 Model Testing"):
            gr.Markdown("## Test Era Classification")
            with gr.Row():
                with gr.Column():
                    test_model_path = gr.Textbox(
                        label="Model Path (optional - uses last trained model if empty)",
                        placeholder="trained_models/longformer_era_classifier_..."
                    )
                    test_input = gr.Textbox(
                        label="Input Text",
                        lines=10,
                        placeholder="Paste historical text here...\n\nExample: 'When that Aprille with his shoures soote, The droghte of Marche hath perced to the roote...'"
                    )
                    test_btn = gr.Button("🔍 Predict Era", variant="primary", size="lg")
                with gr.Column():
                    test_result = gr.Markdown(label="Prediction Result")
                    test_probabilities = gr.Dataframe(label="Top 3 Predictions")
            # Sample texts: one-click demo inputs for three distinct eras.
            gr.Markdown("### 📝 Try Sample Texts")
            with gr.Row():
                sample1 = gr.Button("Medieval Sample")
                sample2 = gr.Button("Victorian Sample")
                sample3 = gr.Button("Contemporary Sample")
            def load_medieval():
                # Opening lines of Beowulf (Old English).
                return "Hwæt! We Gardena in geardagum, þeodcyninga, þrym gefrunon, hu ða æþelingas ellen fremedon."
            def load_victorian():
                return "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife."
            def load_contemporary():
                return "The internet has fundamentally transformed how we communicate, work, and access information in the digital age."
            sample1.click(load_medieval, outputs=[test_input])
            sample2.click(load_victorian, outputs=[test_input])
            sample3.click(load_contemporary, outputs=[test_input])
            test_btn.click(
                predict_era,
                inputs=[test_input, test_model_path],
                outputs=[test_result, test_probabilities]
            )
    gr.Markdown("---")
    # Footer: show compute device and whether the tokenizer loaded at import.
    gr.Markdown(f"**Device:** {DEVICE} | **Status:** {'✅ Ready' if TOKENIZER else '⚠️ Transformers not installed'}")
# Launch the app when run as a script. ssr_mode=False disables Gradio's
# server-side rendering (presumably a workaround for the hosting environment
# — confirm against the deployment notes).
if __name__ == "__main__":
    demo.launch(ssr_mode=False)