# bert_testing_1 / app.py
# Author: faheem66 — commit 32520ce ("changed training approach")
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForTokenClassification, AdamW
from sklearn.model_selection import train_test_split
import gradio as gr
import random
from faker import Faker
import html
import json
import numpy as np
from tqdm import tqdm
import os
# Constants
MAX_LENGTH = 512        # BERT's maximum input sequence length (tokens per example)
BATCH_SIZE = 16         # examples per training/validation batch
EPOCHS = 5              # full passes over the synthetic training set
LEARNING_RATE = 2e-5    # optimizer learning rate, typical for BERT fine-tuning
fake = Faker()          # shared Faker instance used to synthesize employee records
def generate_employee():
    """Synthesize one employee record as a (name, job, ext, email) tuple."""
    full_name = fake.name()
    job_title = fake.job()
    extension = "ext. {}".format(random.randint(1000, 9999))
    # Email is derived from the name: lowercase, spaces become dots.
    email_addr = "{}@example.com".format(full_name.lower().replace(" ", "."))
    return full_name, job_title, extension, email_addr
def generate_html_content(num_employees=3):
    """Render a fake employee-directory page.

    Returns:
        (html_string, employees): the page markup and the list of
        (name, job, ext, email) tuples it was built from.
    """
    employees = [generate_employee() for _ in range(num_employees)]

    # Assemble the page from fragments and join once at the end.
    parts = ["""
<html>
<head>
<title>Employee Directory</title>
</head>
<body>
<div class="row ts-three-column-row standard-row">
"""]
    for name, job, ext, email in employees:
        parts.append(f"""
<div class="column ts-three-column">
<div class="block">
<div class="text-block" style="text-align: center;">
<p>
<strong>{html.escape(name)}</strong><br>
<span style="font-size: 16px">{html.escape(job)}</span><br>
<span style="font-size: 16px">{html.escape(ext)}</span><br>
<a href="mailto:{html.escape(email)}">Send Email</a>
</p>
</div>
</div>
</div>
""")
    parts.append("""
</div>
</body>
</html>
""")
    return "".join(parts), employees
def generate_dataset(num_samples=1000):
    """Build a list of (html_content, employees) training pairs."""
    # generate_html_content already returns the (page, employees) tuple.
    return [generate_html_content() for _ in range(num_samples)]
class HTMLDataset(Dataset):
    """Token-classification dataset over synthetic employee-directory HTML.

    Each item is a BERT encoding of one HTML page plus per-token BIO label ids
    marking employee names, job titles, phone extensions, and email addresses.
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data                # list of (html_string, employees) pairs
        self.tokenizer = tokenizer      # tokenizer providing encode / encode_plus
        self.max_length = max_length    # pad/truncate target length
        # BIO tag -> integer id consumed by the token-classification head.
        self.label_map = {"O": 0, "B-NAME": 1, "I-NAME": 2, "B-JOB": 3, "I-JOB": 4, "B-EXT": 5, "I-EXT": 6,
                          "B-EMAIL": 7, "I-EMAIL": 8}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Renamed from `html` to avoid shadowing the imported html module.
        html_text, employees = self.data[idx]
        encoding = self.tokenizer.encode_plus(
            html_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        labels = self.create_labels(encoding['input_ids'][0], employees)
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

    def create_labels(self, tokens, employees):
        """Return one BIO label id per token; `tokens` may be a tensor or a list.

        BUGFIX: the original compared a tensor slice against a Python list,
        which produces an element-wise bool tensor and makes the `if` raise
        ("Boolean value of Tensor with more than one element is ambiguous").
        Matching is now done on plain Python ints.
        """
        token_ids = tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
        labels = [0] * len(token_ids)  # default everything to "O"
        for name, job, ext, email in employees:
            self.label_sequence(token_ids, name, "NAME", labels)
            self.label_sequence(token_ids, job, "JOB", labels)
            self.label_sequence(token_ids, ext, "EXT", labels)
            self.label_sequence(token_ids, email, "EMAIL", labels)
        return labels

    def label_sequence(self, tokens, text, label_type, labels):
        """Mark the first occurrence of `text`'s token ids with B-/I- labels in place."""
        text_tokens = self.tokenizer.encode(text, add_special_tokens=False)
        if not text_tokens:
            return  # nothing to match for an empty string
        for i in range(len(tokens) - len(text_tokens) + 1):
            if list(tokens[i:i + len(text_tokens)]) == text_tokens:
                labels[i] = self.label_map[f"B-{label_type}"]
                for j in range(1, len(text_tokens)):
                    labels[i + j] = self.label_map[f"I-{label_type}"]
                break
def train_model():
    """Fine-tune BERT for token classification on synthetic HTML pages.

    Returns:
        (model, tokenizer): the trained BertForTokenClassification and its
        BertTokenizer.
    """
    # Generate synthetic dataset and hold out 20% for validation.
    dataset = generate_dataset(num_samples=1000)
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

    # Initialize tokenizer and model (9 labels = O + B/I for 4 entity types).
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=9)

    # Prepare datasets and dataloaders.
    train_dataset = HTMLDataset(train_data, tokenizer, MAX_LENGTH)
    val_dataset = HTMLDataset(val_data, tokenizer, MAX_LENGTH)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # FIX: transformers.AdamW is deprecated and removed in recent versions;
    # use the equivalent torch.optim.AdamW instead.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()  # clear gradients before accumulating new ones
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Validation pass (no gradient tracking).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                val_loss += outputs.loss.item()

        avg_train_loss = train_loss / len(train_dataloader)
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return model, tokenizer
def extract_content(html, model, tokenizer):
    """Run the trained model on an HTML string and return employee info as JSON.

    Args:
        html: raw HTML text to analyze.
        model: trained token-classification model.
        tokenizer: matching tokenizer.

    Returns:
        A JSON string: list of employee dicts keyed by entity type.
    """
    model.eval()
    encoding = tokenizer.encode_plus(
        html,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(model.device)
    attention_mask = encoding['attention_mask'].to(model.device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    predictions = outputs.logits.argmax(dim=2)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    label_map = {0: "O", 1: "B-NAME", 2: "I-NAME", 3: "B-JOB", 4: "I-JOB", 5: "B-EXT", 6: "I-EXT", 7: "B-EMAIL",
                 8: "I-EMAIL"}

    def merge_token(value, token):
        # FIX: WordPiece continuation pieces start with "##" and belong to the
        # previous word — join them without a space and drop the marker. The
        # original space-joined everything, producing values like "jo ##hn".
        if token.startswith("##"):
            return value + token[2:]
        return f"{value} {token}" if value else token

    extracted_info = []
    current_entity = {"type": None, "value": ""}
    for token, prediction in zip(tokens, predictions[0]):
        if token == "[PAD]":
            break  # padding starts; nothing meaningful beyond this point
        if token in ("[CLS]", "[SEP]"):
            continue  # special tokens never carry entity text
        label = label_map[prediction.item()]
        if label.startswith("B-"):
            if current_entity["type"]:
                extracted_info.append(current_entity)
            current_entity = {"type": label[2:], "value": merge_token("", token)}
        elif label.startswith("I-"):
            if current_entity["type"] == label[2:]:
                current_entity["value"] = merge_token(current_entity["value"], token)
        elif label == "O":
            if current_entity["type"]:
                extracted_info.append(current_entity)
            current_entity = {"type": None, "value": ""}
    if current_entity["type"]:
        extracted_info.append(current_entity)

    # Group entities into employees: each NAME starts a new record.
    employees = []
    current_employee = {}
    for entity in extracted_info:
        if entity["type"] == "NAME":
            if current_employee:
                employees.append(current_employee)
            current_employee = {"name": entity["value"]}
        else:
            current_employee[entity["type"].lower()] = entity["value"]
    if current_employee:
        employees.append(current_employee)
    return json.dumps(employees, indent=2)
def test_model(html_input, model, tokenizer):
    """Thin wrapper: extract structured employee info from html_input."""
    return extract_content(html_input, model, tokenizer)
def gradio_interface(html_input, test_type):
    """Gradio callback: extract employee info from custom or generated HTML.

    Returns (html_shown, extracted_json); any other test_type falls through
    and implicitly returns None, matching the original behavior.
    """
    global model, tokenizer
    if test_type == "Generate Random HTML":
        random_html, _ = generate_html_content()
        return random_html, test_model(random_html, model, tokenizer)
    elif test_type == "Custom Input":
        return html_input, test_model(html_input, model, tokenizer)
# ---- Model bootstrap: reuse a cached model if present, otherwise train. ----
if os.path.exists('model.pth') and os.path.exists('tokenizer'):
    print("Loading pre-trained model...")
    model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=9)
    # FIX: map_location lets a checkpoint saved on a GPU machine load on a
    # CPU-only host; without it torch.load raises on deserializing CUDA tensors.
    model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
    tokenizer = BertTokenizer.from_pretrained('tokenizer')
else:
    print("Training new model...")
    model, tokenizer = train_model()
    # Cache weights and tokenizer so later launches skip retraining.
    torch.save(model.state_dict(), 'model.pth')
    tokenizer.save_pretrained('tokenizer')

print("Launching Gradio interface...")
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=10, label="Input HTML"),
        gr.Radio(["Custom Input", "Generate Random HTML"], label="Test Type", value="Custom Input")
    ],
    outputs=[
        gr.Textbox(lines=10, label="HTML Content"),
        gr.Textbox(label="Extracted Information (JSON)")
    ],
    title="HTML Content Extractor",
    description="Enter HTML content or generate random HTML to test the model. The model will extract employee information and return it in JSON format."
)
iface.launch()