wuhp commited on
Commit
79769ad
·
verified ·
1 Parent(s): 59ec29c

Create gpu.py

Browse files
Files changed (1) hide show
  1. gpu.py +808 -0
gpu.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import requests
4
+ import internetarchive
5
+ from datetime import datetime
6
+ import re
7
+ import os
8
+ import shutil
9
+ import time
10
+ import random
11
+ import json
12
+ import torch
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from sklearn.model_selection import train_test_split
15
+ import numpy as np
16
+ import nest_asyncio
17
+ import sys
18
+
19
+ # --- SYSTEM FIXES ---
20
+ try:
21
+ nest_asyncio.apply()
22
+ except Exception as e:
23
+ print(f"Warning: Could not apply nest_asyncio: {e}")
24
+
25
+ # --- CONFIGURATION ---
26
+ DATASET_DIR = "dataset_ml_final_v2"
27
+ BOOKS_DIR = os.path.join(DATASET_DIR, "books")
28
+ MODEL_DIR = "trained_models"
29
+ os.makedirs(MODEL_DIR, exist_ok=True)
30
+
31
+ # --- TOKENIZER & MODEL ---
32
+ TOKENIZER = None
33
+ MODEL = None
34
+ # Check for CUDA support for GPU, otherwise use CPU
35
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+ try:
38
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, logging
39
+ from torch.optim import AdamW
40
+
41
+ logging.set_verbosity_error()
42
+
43
+ print("Attempting to load Longformer Tokenizer...")
44
+ TOKENIZER = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
45
+ print("✅ Tokenizer loaded successfully.")
46
+ except Exception as e:
47
+ print(f"⚠️ Tokenizer loading error: {e}")
48
+ AdamW = None
49
+
50
+ # --- ERAS (10 Distinct Periods) ---
51
+ # DATASET FIX: Updated Search Hints for better boundary distinction
52
+ ERAS = [
53
+ (500, 1200, "0_Medieval", "Medieval OR Latin manuscript OR Anglo-Saxon prose"),
54
+ (1200, 1470, "1_Late_Medieval", "Chaucer OR Middle English OR morality play"),
55
+ (1470, 1650, "2_Early_Modern_Renaissance", "Shakespeare OR Bacon OR Protestant theology OR Early English Bible"),
56
+ (1650, 1800, "3_Enlightenment_Classical", "Pope couplets OR Swift satire OR Neoclassical OR reason science"),
57
+ (1800, 1850, "4_Romantic", "Byron OR Keats OR Shelley OR nature sublime emotion"),
58
+ (1850, 1920, "5_Industrial_Victorian", "Dickens OR Industrial Age OR Darwinism OR social novel"),
59
+ (1920, 1945, "6_Modernist", "Modernism OR stream of consciousness OR avant-garde fiction"),
60
+ (1945, 1960, "7_Postwar_Early_Modern", "Postwar OR Early Cold War OR existentialism"),
61
+ (1960, 1990, "8_Late_20th_Century", "Late 20th Century OR Postmodern OR Vietnam War"),
62
+ (1990, 2024, "9_Contemporary_Information_Age", "Contemporary OR Digital era OR internet culture")
63
+ ]
64
+
65
+ ERA_LABELS = [era[2] for era in ERAS]
66
+ LABEL_TO_ID = {label: idx for idx, label in enumerate(ERA_LABELS)}
67
+ ID_TO_LABEL = {idx: label for idx, label in enumerate(ERA_LABELS)}
68
+
69
+ # --- RESCUE KEYWORDS (Unchanged) ---
70
+ RESCUE_KEYWORDS = {
71
+ "0_Medieval": [
72
+ "Beowulf", "Bede", "Anglo Saxon Chronicle", "Cynewulf", "Caedmon",
73
+ "Old English Homilies", "Aelfric", "Boethius", "Alfred the Great",
74
+ "Venerable Bede", "Old English", "Anglo-Saxon poetry"
75
+ ],
76
+ "1_Late_Medieval": [
77
+ "Chaucer", "Canterbury Tales", "Piers Plowman", "Langland",
78
+ "Gower", "Malory", "Morte d'Arthur", "Wycliffe",
79
+ "Julian Norwich", "Margery Kempe", "Froissart", "Everyman",
80
+ "Gawain", "Pearl Poet", "Lydgate", "Troilus Criseyde",
81
+ "Book Duchess", "Parliament Fowls", "Legend Good Women",
82
+ "Christine Pizan", "Romance Rose", "Confessio Amantis",
83
+ "mystery plays", "miracle plays", "morality plays",
84
+ "Middle English", "medieval romance", "medieval literature",
85
+ "14th century literature", "15th century literature",
86
+ "medieval poetry", "medieval drama", "Arthurian legend",
87
+ "Chivalric romance", "Courtly love", "medieval manuscript",
88
+ "Caxton", "medieval texts", "English medieval", "French medieval"
89
+ ]
90
+ }
91
+
92
+ LATE_MEDIEVAL_COLLECTIONS = [
93
+ "gutenberg", "opensource", "medievaltexts", "earlyenglishbooksonline",
94
+ "englishliterature", "medievalmanuscripts", "britishlibrary"
95
+ ]
96
+
97
+ # DATASET FIX: Added more contemporary-friendly topics
98
+ TOPICS = [
99
+ "History", "Philosophy", "Science", "Mathematics", "Medicine", "Astronomy",
100
+ "Physics", "Chemistry", "Biology", "Fiction", "Poetry", "Drama", "Mythology",
101
+ "Folklore", "Religion", "Theology", "Biography", "Politics", "Economics", "Law",
102
+ "Sociology", "Technology", "Engineering", "Travel", "War", "Military", "Art",
103
+ "Psychology", "Anthropology", "Literature", "Essays", "Memoirs", "Education",
104
+ "Computer Programming", "Digital Culture", "Current Affairs"
105
+ ]
106
+
107
+ # ============================================================================
108
+ # TAB 1: DATASET GENERATION
109
+ # ============================================================================
110
+
111
+ def setup_dirs():
112
+ if os.path.exists(DATASET_DIR):
113
+ try: shutil.rmtree(DATASET_DIR)
114
+ except: pass
115
+ os.makedirs(BOOKS_DIR, exist_ok=True)
116
+
117
+ def text_quality_check(text):
118
+ """
119
+ A light-weight quality check to filter out poor scan or boilerplate text.
120
+ """
121
+ if len(text) < 3000: return False
122
+ alpha_count = sum(c.isalpha() for c in text)
123
+ total_count = len(text)
124
+ if alpha_count / (total_count + 1e-6) < 0.5: return False
125
+
126
+ start_snippet = text[:1000].lower()
127
+ boilerplate_indicators = ["table of contents", "chapter i", "preface", "index", "list of figures"]
128
+ if any(indicator in start_snippet for indicator in boilerplate_indicators):
129
+ if len(text) < 10000: return False
130
+
131
+ lines = text.split('\n')
132
+ from collections import Counter
133
+ line_counts = Counter(l.strip() for l in lines if l.strip())
134
+
135
+ if len(line_counts) < 50: return False
136
+
137
+ frequent_lines = sum(1 for count in line_counts.values() if count >= 3)
138
+ if frequent_lines / len(line_counts) > 0.1: return False
139
+
140
+ return True
141
+
142
+
143
+ def chunk_text_robust(text):
144
+ MAX_TOKENS = 3500
145
+ STRIDE = 500
146
+ MAX_CHUNKS_PER_BOOK = 40
147
+ chunks = []
148
+
149
+ if TOKENIZER:
150
+ try:
151
+ tokens = TOKENIZER.encode(text, add_special_tokens=False)
152
+ i = 0
153
+ while i < len(tokens) and len(chunks) < MAX_CHUNKS_PER_BOOK:
154
+ chunk_ids = tokens[i : i + MAX_TOKENS]
155
+ chunk_str = TOKENIZER.decode(chunk_ids, skip_special_tokens=True)
156
+ chunks.append(chunk_str)
157
+ i += (MAX_TOKENS - STRIDE)
158
+ return chunks
159
+ except: pass
160
+
161
+ WORDS_PER_CHUNK = 2700
162
+ WORD_STRIDE = 400
163
+ words = text.split()
164
+ i = 0
165
+ while i < len(words) and len(chunks) < MAX_CHUNKS_PER_BOOK:
166
+ chunk_words = words[i : i + WORDS_PER_CHUNK]
167
+ chunk_str = " ".join(chunk_words)
168
+ if len(chunk_str) > 300:
169
+ chunks.append(chunk_str)
170
+ i += (WORDS_PER_CHUNK - WORD_STRIDE)
171
+ return chunks
172
+
173
+ # ⭐️ FIX: Ensuring clean_text_content is defined before download_book
174
+ def clean_text_content(text):
175
+ markers = [("*** START OF", "*** END OF")]
176
+ for start_m, end_m in markers:
177
+ s = text.find(start_m)
178
+ e = text.find(end_m)
179
+ if s != -1 and e != -1:
180
+ text = text[s+len(start_m):e]
181
+ break
182
+ return text.strip()
183
+
184
+ # MODIFIED download_book to accept a bypass flag
185
+ def download_book(identifier, title, year, era_label, min_char_limit=5000, bypass_quality_check=False):
186
+ urls = [
187
+ f"https://archive.org/download/{identifier}/{identifier}_djvu.txt",
188
+ f"https://archive.org/download/{identifier}/{identifier}.txt"
189
+ ]
190
+ content = ""
191
+ for url in urls:
192
+ try:
193
+ r = requests.get(url, timeout=15)
194
+ if r.status_code == 200:
195
+ content = r.text
196
+ break
197
+ except: pass
198
+
199
+ content = clean_text_content(content) # <-- The line that was failing
200
+
201
+ if len(content) < min_char_limit:
202
+ return None
203
+
204
+ if not bypass_quality_check:
205
+ if not text_quality_check(content):
206
+ return None
207
+
208
+ safe_title = re.sub(r'[^a-zA-Z0-9]', '_', title)[:40]
209
+ filename = f"{year}_{era_label}_{safe_title}_{identifier}.txt"
210
+ with open(os.path.join(BOOKS_DIR, filename), "w", encoding="utf-8") as f:
211
+ f.write(content)
212
+
213
+ return {
214
+ "title": title, "year": int(year), "era_label": era_label,
215
+ "filename": filename, "char_count": len(content), "source": "Internet Archive"
216
+ }
217
+
218
+ def generate_dataset(total_books_needed, progress=gr.Progress()):
219
+ setup_dirs()
220
+ records = []
221
+
222
+ books_per_era = max(1, int(total_books_needed / len(ERAS)))
223
+
224
+ for start_year, end_year, era_label, search_hint in ERAS:
225
+ collected = 0
226
+ attempts = 0
227
+ era_topics = TOPICS.copy()
228
+ random.shuffle(era_topics)
229
+
230
+ rescue_list = RESCUE_KEYWORDS.get(era_label, [])
231
+ is_hard_era = len(rescue_list) > 0
232
+
233
+ min_chars = 5000
234
+ bypass_qc = False
235
+
236
+ # FIX: Specialized rules for Contemporary Era
237
+ if era_label == "9_Contemporary_Information_Age":
238
+ min_chars = 2000 # Lower character requirement
239
+ bypass_qc = True # Disable strict quality check
240
+ max_attempts = 40 # Increase max attempts for this hard era
241
+ elif is_hard_era:
242
+ min_chars = 1000
243
+ max_attempts = 50 if era_label == "1_Late_Medieval" else 20
244
+ else:
245
+ max_attempts = 20
246
+
247
+ rescue_threshold = 0 if era_label == "1_Late_Medieval" else 3
248
+
249
+ progress(0, desc=f"Scraping Era: {era_label}")
250
+ print(f"\n{'='*60}")
251
+ print(f"Starting Era: {era_label} (Target: {books_per_era} books | Min Chars: {min_chars})")
252
+ print(f"{'='*60}")
253
+
254
+ while collected < books_per_era and attempts < max_attempts:
255
+ attempts += 1
256
+ using_rescue = False
257
+
258
+ if is_hard_era and attempts > rescue_threshold:
259
+ using_rescue = True
260
+ kw = random.choice(rescue_list)
261
+
262
+ if era_label == "1_Late_Medieval":
263
+ query_type = attempts % 6
264
+ if query_type == 0: query = f"title:({kw}) AND mediatype:texts"
265
+ elif query_type == 1: query = f"({kw}) AND mediatype:texts AND language:eng"
266
+ elif query_type == 2: query = f"subject:({kw}) AND mediatype:texts"
267
+ elif query_type == 3:
268
+ col = random.choice(LATE_MEDIEVAL_COLLECTIONS)
269
+ query = f"({kw}) AND collection:({col}) AND mediatype:texts"
270
+ elif query_type == 4: query = f"({kw}) AND date:[1200 TO 1900] AND mediatype:texts AND language:eng"
271
+ else: query = f"{kw} mediatype:texts"
272
+ else:
273
+ if attempts % 3 == 0: query = f"title:({kw}) AND mediatype:texts"
274
+ elif attempts % 3 == 1: query = f"({kw}) AND mediatype:texts AND language:eng"
275
+ else: query = f"subject:({kw}) AND mediatype:texts"
276
+ print(f" > 🛡️ Rescue Search #{attempts} ({era_label}): {kw}")
277
+ else:
278
+ if not era_topics:
279
+ era_topics = TOPICS.copy()
280
+ random.shuffle(era_topics)
281
+ topic = era_topics.pop()
282
+ query = f"(subject:{topic} OR {search_hint}) AND date:[{start_year} TO {end_year}] AND mediatype:texts AND language:eng"
283
+ if end_year > 1928:
284
+ query += " AND (licenseurl:* OR rights:creative commons OR collection:opensourcemedia)"
285
+ print(f" > Standard Search #{attempts}: {topic} | Hint: {search_hint.split(' OR ')[0]}...")
286
+
287
+ try:
288
+ search_generator = internetarchive.search_items(
289
+ query,
290
+ sorts=['downloads desc'],
291
+ fields=['identifier', 'title', 'date', 'year']
292
+ )
293
+
294
+ search_results_batch = []
295
+ max_check_per_query = (50 if is_hard_era or era_label == "9_Contemporary_Information_Age" else 15)
296
+ for i, item in enumerate(search_generator):
297
+ search_results_batch.append(item)
298
+ if i >= max_check_per_query: break
299
+
300
+ results_found = len(search_results_batch)
301
+
302
+ for res in search_results_batch:
303
+ if collected >= books_per_era: break
304
+
305
+ id_ = res.get('identifier')
306
+ raw_date = res.get('date') or res.get('year')
307
+ year = str(raw_date)[:4] if raw_date else "0000"
308
+
309
+ if not year.isdigit(): year = "0000"
310
+
311
+ if not using_rescue:
312
+ if not (start_year <= int(year) <= end_year):
313
+ continue
314
+
315
+ if any(r['filename'].endswith(f"{id_}.txt") for r in records):
316
+ continue
317
+
318
+ rec = download_book(
319
+ id_, res.get('title', 'Unknown'), year, era_label,
320
+ min_char_limit=min_chars,
321
+ bypass_quality_check=bypass_qc
322
+ )
323
+
324
+ if rec:
325
+ rec['topic'] = "Classic" if using_rescue else topic
326
+ records.append(rec)
327
+ collected += 1
328
+ print(f" ✅ Saved ({collected}/{books_per_era}): {rec['title']} ({year}) | Chars: {rec['char_count']}")
329
+
330
+ if results_found == 0:
331
+ print(f" ⚠️ No results found for this query")
332
+
333
+ except Exception as e:
334
+ print(f" ❌ Search error: {e}")
335
+ time.sleep(1)
336
+
337
+ print(f"Completed {era_label}: {collected}/{books_per_era} books collected")
338
+
339
+ # ... (Fallback logic for Late Medieval remains) ...
340
+ if era_label == "1_Late_Medieval" and collected < books_per_era * 0.3:
341
+ print(f"\n⚠️ EMERGENCY FALLBACK MODE for {era_label}")
342
+ fallback_attempts = 0
343
+ fallback_terms = [
344
+ "medieval english", "middle english", "chaucer OR malory OR gower",
345
+ "14th century OR 15th century", "medieval literature english",
346
+ "arthurian romance", "medieval poetry english"
347
+ ]
348
+
349
+ while collected < books_per_era and fallback_attempts < len(fallback_terms):
350
+ term = fallback_terms[fallback_attempts]
351
+ fallback_attempts += 1
352
+ query = f"({term}) AND mediatype:texts"
353
+ print(f" > 🚨 Fallback #{fallback_attempts}: {term}")
354
+
355
+ try:
356
+ search_generator = internetarchive.search_items(query, sorts=['downloads desc'], fields=['identifier', 'title', 'date', 'year'])
357
+
358
+ fallback_batch = []
359
+ for i, item in enumerate(search_generator):
360
+ fallback_batch.append(item)
361
+ if i >= 100: break
362
+
363
+ checked = 0
364
+ for res in fallback_batch:
365
+ if collected >= books_per_era:
366
+ break
367
+ checked += 1
368
+
369
+ id_ = res.get('identifier')
370
+ if any(r['filename'].endswith(f"{id_}.txt") for r in records):
371
+ continue
372
+
373
+ raw_date = res.get('date') or res.get('year')
374
+ year = str(raw_date)[:4] if raw_date else "0000"
375
+ if not year.isdigit(): year = "0000"
376
+
377
+ rec = download_book(
378
+ id_, res.get('title', 'Unknown'), year, era_label,
379
+ min_char_limit=min_chars,
380
+ bypass_quality_check=bypass_qc
381
+ )
382
+ if rec:
383
+ rec['topic'] = "Medieval"
384
+ records.append(rec)
385
+ collected += 1
386
+ print(f" ✅ FALLBACK Success ({collected}/{books_per_era}): {rec['title']} | Chars: {rec['char_count']}")
387
+ except Exception as e:
388
+ print(f" ❌ Fallback error: {e}")
389
+ time.sleep(1)
390
+
391
+ if not records: return None, pd.DataFrame(), pd.DataFrame()
392
+
393
+ print("\n" + "="*60)
394
+ print("Starting Robust Chunking...")
395
+ print("="*60)
396
+ progress(0.9, desc="Chunking Text...")
397
+ longformer_rows = []
398
+
399
+ for r in records:
400
+ file_path = os.path.join(BOOKS_DIR, r["filename"])
401
+ try:
402
+ with open(file_path, "r", encoding="utf-8") as f:
403
+ raw_text = f.read()
404
+ chunks = chunk_text_robust(raw_text)
405
+ for idx, chunk in enumerate(chunks):
406
+ longformer_rows.append({
407
+ "text": chunk,
408
+ "era_label": r["era_label"],
409
+ "year": r["year"],
410
+ "chunk_id": idx
411
+ })
412
+ print(f" ✅ Chunked {r['title']}: {len(chunks)} chunks")
413
+ except Exception as e:
414
+ print(f" ❌ Error processing {r['filename']}: {e}")
415
+
416
+ df_rows = pd.DataFrame(longformer_rows)
417
+ if not df_rows.empty:
418
+ split_stats = df_rows['era_label'].value_counts().reset_index()
419
+ split_stats.columns = ['Era Label', 'Total Chunks']
420
+ split_stats['Est. Train (80%)'] = (split_stats['Total Chunks'] * 0.8).astype(int)
421
+ split_stats['Est. Val (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
422
+ split_stats['Est. Test (10%)'] = (split_stats['Total Chunks'] * 0.1).astype(int)
423
+ split_stats['Status'] = split_stats['Est. Val (10%)'].apply(lambda x: "⚠️ LOW DATA" if x < 5 else "✅ OK")
424
+ else:
425
+ split_stats = pd.DataFrame()
426
+
427
+ total_chunks = len(longformer_rows)
428
+ avg_chunks = total_chunks / len(records) if records else 0
429
+ general_stats_df = pd.DataFrame({
430
+ "Metric": ["Total Books", "Total Training Examples", "Avg Examples/Book"],
431
+ "Value": [len(records), total_chunks, f"{avg_chunks:.1f}"]
432
+ })
433
+
434
+ pd.DataFrame(records).to_csv(os.path.join(DATASET_DIR, "metadata.csv"), index=False)
435
+ jsonl_path = os.path.join(DATASET_DIR, "longformer_dataset.jsonl")
436
+ with open(jsonl_path, "w", encoding="utf-8") as f:
437
+ for row in longformer_rows:
438
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
439
+
440
+ timestamp = int(datetime.now().timestamp())
441
+ zip_filename = f"Analyzed_ML_Dataset_{timestamp}"
442
+ shutil.make_archive(zip_filename, 'zip', DATASET_DIR)
443
+
444
+ print("\n" + "="*60)
445
+ print("Dataset Generation Complete! READY FOR RETRAINING.")
446
+ print("="*60)
447
+
448
+ return f"{zip_filename}.zip", general_stats_df, split_stats
449
+
450
+ # ============================================================================
451
+ # TAB 2: TRAINING (No changes needed, already optimized for 4080 Super)
452
+ # ============================================================================
453
+
454
+ class LongformerDataset(Dataset):
455
+ def __init__(self, texts, labels, tokenizer, max_length=4096):
456
+ self.texts = texts
457
+ self.labels = labels
458
+ self.tokenizer = tokenizer
459
+ self.max_length = max_length
460
+
461
+ def __len__(self):
462
+ return len(self.texts)
463
+
464
+ def __getitem__(self, idx):
465
+ text = str(self.texts[idx])
466
+ label = self.labels[idx]
467
+
468
+ encoding = self.tokenizer(
469
+ text,
470
+ add_special_tokens=True,
471
+ max_length=self.max_length,
472
+ padding='max_length',
473
+ truncation=True,
474
+ return_tensors='pt'
475
+ )
476
+
477
+ return {
478
+ 'input_ids': encoding['input_ids'].flatten(),
479
+ 'attention_mask': encoding['attention_mask'].flatten(),
480
+ 'labels': torch.tensor(label, dtype=torch.long)
481
+ }
482
+
483
+ def train_model(dataset_path, epochs, batch_size, learning_rate, gradient_accumulation_steps, progress=gr.Progress()):
484
+ global MODEL, TOKENIZER
485
+
486
+ if not TOKENIZER:
487
+ return "❌ Tokenizer not loaded. Please install transformers library.", None, None
488
+
489
+ if not os.path.exists(dataset_path):
490
+ return "❌ Dataset file not found. Please generate a dataset first.", None, None
491
+
492
+ if batch_size < 1:
493
+ return "❌ Error: Batch Size must be at least 1.", None, None
494
+ if gradient_accumulation_steps < 1:
495
+ return "❌ Error: Gradient Accumulation Steps must be at least 1.", None, None
496
+
497
+ scaler = torch.cuda.amp.GradScaler() if DEVICE == "cuda" else None
498
+
499
+ try:
500
+ progress(0.1, desc="Loading dataset...")
501
+ data = []
502
+ with open(dataset_path, 'r', encoding='utf-8') as f:
503
+ for line in f:
504
+ data.append(json.loads(line))
505
+
506
+ df = pd.DataFrame(data)
507
+ texts = df['text'].tolist()
508
+ labels = [LABEL_TO_ID[label] for label in df['era_label'].tolist()]
509
+
510
+ progress(0.2, desc="Splitting data...")
511
+ X_train, X_temp, y_train, y_temp = train_test_split(texts, labels, test_size=0.2, random_state=42, stratify=labels)
512
+ X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)
513
+
514
+ train_dataset = LongformerDataset(X_train, y_train, TOKENIZER)
515
+ val_dataset = LongformerDataset(X_val, y_val, TOKENIZER)
516
+
517
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
518
+ val_loader = DataLoader(val_dataset, batch_size=batch_size)
519
+
520
+ progress(0.3, desc="Initializing model...")
521
+ MODEL = AutoModelForSequenceClassification.from_pretrained(
522
+ "allenai/longformer-base-4096",
523
+ num_labels=len(LABEL_TO_ID)
524
+ )
525
+ MODEL.to(DEVICE)
526
+
527
+ optimizer = AdamW(MODEL.parameters(), lr=learning_rate)
528
+
529
+ total_batches = len(train_loader)
530
+ total_training_steps = (total_batches // gradient_accumulation_steps) * epochs
531
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_training_steps)
532
+
533
+ train_losses = []
534
+ val_accuracies = []
535
+ step_count = 0
536
+
537
+ for epoch in range(epochs):
538
+ MODEL.train()
539
+ total_loss = 0
540
+
541
+ for batch_idx, batch in enumerate(train_loader):
542
+ progress_val = (0.3 + (epoch / epochs) * 0.6) + ((batch_idx / total_batches) / epochs * 0.6)
543
+ progress(progress_val, desc=f"Training Epoch {epoch+1}/{epochs} (Batch {batch_idx+1}/{total_batches})")
544
+
545
+ input_ids = batch['input_ids'].to(DEVICE)
546
+ attention_mask = batch['attention_mask'].to(DEVICE)
547
+ labels = batch['labels'].to(DEVICE)
548
+
549
+ with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
550
+ outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
551
+ loss = outputs.loss
552
+ loss = loss / gradient_accumulation_steps
553
+
554
+ if scaler:
555
+ scaler.scale(loss).backward()
556
+ else:
557
+ loss.backward()
558
+
559
+ total_loss += loss.item() * gradient_accumulation_steps
560
+ step_count += 1
561
+
562
+ if step_count % gradient_accumulation_steps == 0 or batch_idx == total_batches - 1:
563
+ if scaler:
564
+ scaler.unscale_(optimizer)
565
+ torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
566
+ scaler.step(optimizer)
567
+ scaler.update()
568
+ else:
569
+ torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
570
+ optimizer.step()
571
+
572
+ scheduler.step()
573
+ optimizer.zero_grad()
574
+
575
+ MODEL.eval()
576
+ correct = 0
577
+ total = 0
578
+
579
+ with torch.no_grad():
580
+ with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
581
+ for batch in val_loader:
582
+ input_ids = batch['input_ids'].to(DEVICE)
583
+ attention_mask = batch['attention_mask'].to(DEVICE)
584
+ labels = batch['labels'].to(DEVICE)
585
+
586
+ outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
587
+ predictions = torch.argmax(outputs.logits, dim=1)
588
+
589
+ correct += (predictions == labels).sum().item()
590
+ total += labels.size(0)
591
+
592
+ avg_loss = total_loss / total_batches
593
+ val_acc = correct / total
594
+
595
+ train_losses.append(avg_loss)
596
+ val_accuracies.append(val_acc)
597
+
598
+ print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}")
599
+
600
+ progress(0.95, desc="Saving model...")
601
+ timestamp = int(datetime.now().timestamp())
602
+ model_path = os.path.join(MODEL_DIR, f"longformer_era_classifier_{timestamp}")
603
+ MODEL.save_pretrained(model_path)
604
+ TOKENIZER.save_pretrained(model_path)
605
+
606
+ metrics_df = pd.DataFrame({
607
+ "Epoch": list(range(1, epochs + 1)),
608
+ "Training Loss": train_losses,
609
+ "Validation Accuracy": [f"{acc:.4f}" for acc in val_accuracies]
610
+ })
611
+
612
+ summary = f"✅ Training Complete!\nFinal Val Acc: {val_accuracies[-1]:.4f}\nModel saved to: {model_path}"
613
+
614
+ return summary, metrics_df, model_path
615
+
616
+ except RuntimeError as e:
617
+ if 'out of memory' in str(e):
618
+ if DEVICE == "cuda": torch.cuda.empty_cache()
619
+ return f"❌ Training error: CUDA Out Of Memory. Try reducing the 'Batch Size' slider to 1, or increase 'Gradient Accumulation Steps'. Error: {str(e)}", None, None
620
+ return f"❌ Training error: {str(e)}", None, None
621
+ except Exception as e:
622
+ return f"❌ Training error: {str(e)}", None, None
623
+
624
+ # ============================================================================
625
+ # TAB 3: TESTING (No changes needed)
626
+ # ============================================================================
627
+
628
+ def load_trained_model(model_path):
629
+ global MODEL, TOKENIZER
630
+
631
+ try:
632
+ TOKENIZER = AutoTokenizer.from_pretrained(model_path)
633
+ MODEL = AutoModelForSequenceClassification.from_pretrained(model_path)
634
+ MODEL.to(DEVICE)
635
+ MODEL.eval()
636
+ return f"✅ Model loaded successfully from {model_path}"
637
+ except Exception as e:
638
+ return f"❌ Error loading model: {str(e)}"
639
+
640
+ def predict_era(text, model_path):
641
+ global MODEL, TOKENIZER
642
+
643
+ if not MODEL or not TOKENIZER:
644
+ if model_path and os.path.exists(model_path):
645
+ load_result = load_trained_model(model_path)
646
+ if "Error" in load_result:
647
+ return load_result, None
648
+ else:
649
+ return "❌ No model loaded. Please train a model first or provide a valid model path.", None
650
+
651
+ try:
652
+ encoding = TOKENIZER(
653
+ text,
654
+ add_special_tokens=True,
655
+ max_length=4096,
656
+ padding='max_length',
657
+ truncation=True,
658
+ return_tensors='pt'
659
+ )
660
+
661
+ input_ids = encoding['input_ids'].to(DEVICE)
662
+ attention_mask = encoding['attention_mask'].to(DEVICE)
663
+
664
+ with torch.no_grad():
665
+ with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
666
+ outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
667
+ logits = outputs.logits
668
+ probabilities = torch.softmax(logits, dim=1)[0]
669
+ predicted_class = torch.argmax(probabilities).item()
670
+
671
+ top_3_probs, top_3_indices = torch.topk(probabilities, 3)
672
+
673
+ results = []
674
+ for idx, prob in zip(top_3_indices, top_3_probs):
675
+ era_label = ID_TO_LABEL[idx.item()]
676
+ confidence = prob.item() * 100
677
+ results.append({
678
+ "Era": era_label,
679
+ "Confidence": f"{confidence:.2f}%"
680
+ })
681
+
682
+ predicted_era = ID_TO_LABEL[predicted_class]
683
+ result_text = f"🎯 **Predicted Era:** {predicted_era}\n\n**Confidence:** {probabilities[predicted_class].item()*100:.2f}%"
684
+
685
+ return result_text, pd.DataFrame(results)
686
+
687
+ except Exception as e:
688
+ return f"❌ Prediction error: {str(e)}", None
689
+
690
+ # ============================================================================
691
+ # GRADIO UI
692
+ # ============================================================================
693
+
694
+ with gr.Blocks(title="Complete ML Pipeline") as demo:
695
+ gr.Markdown("# 📚 Complete ML Pipeline: Dataset Generation, Training & Testing (RTX 4080 Super Optimized)")
696
+
697
+ with gr.Tabs():
698
+ # TAB 1: Dataset Generation
699
+ with gr.Tab("📊 Dataset Generation"):
700
+ gr.Markdown("## Generate Historical Text Dataset")
701
+ gr.Markdown("""
702
+ **DATA QUALITY FIX:** Contemporary Era (`9_...`) now has lower length requirements and a less strict quality check to compensate for scarce open-source post-1990 data.
703
+ """)
704
+
705
+ with gr.Row():
706
+ dataset_slider = gr.Slider(10, 500, step=10, value=100, label="Total Books to Collect (Max 500)")
707
+ generate_btn = gr.Button("🚀 Generate Dataset (New Data Quality)", variant="primary", size="lg")
708
+
709
+ dataset_download = gr.File(label="📥 Download Dataset ZIP")
710
+
711
+ with gr.Row():
712
+ with gr.Column():
713
+ gr.Markdown("### General Summary")
714
+ gen_stats = gr.Dataframe()
715
+ with gr.Column():
716
+ gr.Markdown("### Class Balance Check")
717
+ split_stats = gr.Dataframe()
718
+
719
+ generate_btn.click(
720
+ generate_dataset,
721
+ inputs=[dataset_slider],
722
+ outputs=[dataset_download, gen_stats, split_stats]
723
+ )
724
+
725
+ # TAB 2: Training
726
+ with gr.Tab("🎓 Model Training"):
727
+ gr.Markdown("## Train Longformer Era Classifier")
728
+ gr.Markdown(f"""
729
+ **GPU OPTIMIZED:** Training now uses **Automatic Mixed Precision (FP16/AMP)** for the RTX 4080 Super.
730
+ With 16GB VRAM, you can use a higher **Batch Size** (e.g., 4 or 8) and often set **Gradient Accumulation Steps** to 1.
731
+ """)
732
+
733
+ with gr.Row():
734
+ with gr.Column():
735
+ train_dataset_path = gr.Textbox(
736
+ label="Dataset Path",
737
+ value=os.path.join(DATASET_DIR, "longformer_dataset.jsonl"),
738
+ placeholder="Path to dataset JSONL file"
739
+ )
740
+ train_epochs = gr.Slider(1, 10, step=1, value=3, label="Epochs")
741
+ train_batch = gr.Slider(1, 16, step=1, value=4, label="Batch Size (Memory Control)")
742
+ train_accum = gr.Slider(1, 16, step=1, value=1, label="Gradient Accumulation Steps (Effective Batch Size)")
743
+ train_lr = gr.Number(value=2e-5, label="Learning Rate")
744
+ train_btn = gr.Button("🏋️ Start Training", variant="primary", size="lg")
745
+
746
+ with gr.Column():
747
+ train_output = gr.Textbox(label="Training Status", lines=8)
748
+ train_metrics = gr.Dataframe(label="Training Metrics")
749
+ model_path_output = gr.Textbox(label="Saved Model Path")
750
+
751
+ train_btn.click(
752
+ train_model,
753
+ inputs=[train_dataset_path, train_epochs, train_batch, train_lr, train_accum],
754
+ outputs=[train_output, train_metrics, model_path_output]
755
+ )
756
+
757
+ # TAB 3: Testing
758
+ with gr.Tab("🧪 Model Testing"):
759
+ gr.Markdown("## Test Era Classification (FP16/AMP Inference)")
760
+
761
+ with gr.Row():
762
+ with gr.Column():
763
+ test_model_path = gr.Textbox(
764
+ label="Model Path (optional - uses last trained model if empty)",
765
+ placeholder="trained_models/longformer_era_classifier_..."
766
+ )
767
+ test_input = gr.Textbox(
768
+ label="Input Text",
769
+ lines=10,
770
+ placeholder="Paste historical text here...\n\nExample: 'When that Aprille with his shoures soote, The droghte of Marche hath perced to the roote...'"
771
+ )
772
+ test_btn = gr.Button("🔍 Predict Era", variant="primary", size="lg")
773
+
774
+ with gr.Column():
775
+ test_result = gr.Markdown(label="Prediction Result")
776
+ test_probabilities = gr.Dataframe(label="Top 3 Predictions")
777
+
778
+ # Sample texts
779
+ gr.Markdown("### 📝 Try Sample Texts")
780
+ with gr.Row():
781
+ sample1 = gr.Button("Medieval Sample")
782
+ sample2 = gr.Button("Victorian Sample")
783
+ sample3 = gr.Button("Contemporary Sample")
784
+
785
+ def load_medieval():
786
+ return "Hwæt! We Gardena in geardagum, þeodcyninga, þrym gefrunon, hu ða æþelingas ellen fremedon."
787
+
788
+ def load_victorian():
789
+ return "It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife."
790
+
791
+ def load_contemporary():
792
+ return "The internet has fundamentally transformed how we communicate, work, and access information in the digital age."
793
+
794
+ sample1.click(load_medieval, outputs=[test_input])
795
+ sample2.click(load_victorian, outputs=[test_input])
796
+ sample3.click(load_contemporary, outputs=[test_input])
797
+
798
+ test_btn.click(
799
+ predict_era,
800
+ inputs=[test_input, test_model_path],
801
+ outputs=[test_result, test_probabilities]
802
+ )
803
+
804
+ gr.Markdown("---")
805
+ gr.Markdown(f"**Device:** {DEVICE} | **Status:** {'✅ CUDA/FP16 Ready' if DEVICE == 'cuda' else '⚠️ CPU Mode'} | **Model:** Longformer-base-4096")
806
+
807
+ if __name__ == "__main__":
808
+ demo.launch(ssr_mode=False)