import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed; use torch's
from transformers import BertTokenizer, BertForTokenClassification
from sklearn.model_selection import train_test_split
import gradio as gr
import random
from faker import Faker
import html
import json
from tqdm import tqdm
import os
# Constants
MAX_LENGTH = 512
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 2e-5

fake = Faker()
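
# --- Synthetic data generation ---
# Each record is a (name, job, extension, email) tuple; the email is derived
# from the name so a sample's fields stay mutually consistent.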
def generate_employee():
    name = fake.name()
    job = fake.job()
    ext = f"ext. {random.randint(1000, 9999)}"
    email = f"{name.lower().replace(' ', '.')}@example.com"
    return name, job, ext, email

def generate_html_content(num_employees=3):
    employees = [generate_employee() for _ in range(num_employees)]
    html_content = """
    <html>
    <head>
        <title>Employee Directory</title>
    </head>
    <body>
        <div class="row ts-three-column-row standard-row">
    """
    for name, job, ext, email in employees:
        html_content += f"""
        <div class="column ts-three-column">
            <div class="block">
                <div class="text-block" style="text-align: center;">
                    <p>
                        <strong>{html.escape(name)}</strong><br>
                        <span style="font-size: 16px">{html.escape(job)}</span><br>
                        <span style="font-size: 16px">{html.escape(ext)}</span><br>
                        <a href="mailto:{html.escape(email)}">Send Email</a>
                    </p>
                </div>
            </div>
        </div>
        """
    html_content += """
        </div>
    </body>
    </html>
    """
    return html_content, employees

def generate_dataset(num_samples=1000):
    dataset = []
    for _ in range(num_samples):
        html_content, employees = generate_html_content()
        dataset.append((html_content, employees))
    return dataset
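
# --- Dataset ---
# Labels follow the BIO scheme: "B-" marks the first wordpiece of an entity,
# "I-" its continuations, and "O" everything else. encode_plus truncates at
# MAX_LENGTH wordpieces, so entities past that point are never labeled.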

class HTMLDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = {"O": 0, "B-NAME": 1, "I-NAME": 2, "B-JOB": 3, "I-JOB": 4,
                          "B-EXT": 5, "I-EXT": 6, "B-EMAIL": 7, "I-EMAIL": 8}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        html_content, employees = self.data[idx]
        encoding = self.tokenizer.encode_plus(
            html_content,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Convert to a plain list so the slice comparisons in label_sequence work.
        labels = self.create_labels(encoding['input_ids'][0].tolist(), employees)
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

    def create_labels(self, tokens, employees):
        labels = [0] * len(tokens)  # Initialize all labels as "O"
        for name, job, ext, email in employees:
            self.label_sequence(tokens, name, "NAME", labels)
            self.label_sequence(tokens, job, "JOB", labels)
            self.label_sequence(tokens, ext, "EXT", labels)
            self.label_sequence(tokens, email, "EMAIL", labels)
        return labels

    def label_sequence(self, tokens, text, label_type, labels):
        # Tag only the first occurrence of the entity's token sequence.
        text_tokens = self.tokenizer.encode(text, add_special_tokens=False)
        for i in range(len(tokens) - len(text_tokens) + 1):
            if tokens[i:i + len(text_tokens)] == text_tokens:
                labels[i] = self.label_map[f"B-{label_type}"]
                for j in range(1, len(text_tokens)):
                    labels[i + j] = self.label_map[f"I-{label_type}"]
                break
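
# --- Training ---
# Fine-tunes bert-base-uncased for token classification on the synthetic data
# and reports average train/validation loss per epoch.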

def train_model():
    # Generate synthetic dataset
    dataset = generate_dataset(num_samples=1000)
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

    # Initialize tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=9)

    # Prepare datasets and dataloaders
    train_dataset = HTMLDataset(train_data, tokenizer, MAX_LENGTH)
    val_dataset = HTMLDataset(val_data, tokenizer, MAX_LENGTH)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                val_loss += outputs.loss.item()

        avg_train_loss = train_loss / len(train_dataloader)
        avg_val_loss = val_loss / len(val_dataloader)
        print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    return model, tokenizer
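
# --- Inference ---
# Decodes BIO predictions into entity spans, then groups spans into
# per-employee records: each NAME entity starts a new record.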

def extract_content(html_input, model, tokenizer):
    model.eval()
    encoding = tokenizer.encode_plus(
        html_input,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(model.device)
    attention_mask = encoding['attention_mask'].to(model.device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    predictions = outputs.logits.argmax(dim=2)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    label_map = {0: "O", 1: "B-NAME", 2: "I-NAME", 3: "B-JOB", 4: "I-JOB",
                 5: "B-EXT", 6: "I-EXT", 7: "B-EMAIL", 8: "I-EMAIL"}
    extracted_info = []
    current_entity = {"type": None, "value": ""}
    for token, prediction in zip(tokens, predictions[0]):
        if token == "[PAD]":
            break
        label = label_map[prediction.item()]
        if label.startswith("B-"):
            if current_entity["type"]:
                extracted_info.append(current_entity)
            current_entity = {"type": label[2:], "value": token}
        elif label.startswith("I-"):
            if current_entity["type"] == label[2:]:
                # Glue WordPiece continuations (##-prefixed) back onto the word.
                if token.startswith("##"):
                    current_entity["value"] += token[2:]
                else:
                    current_entity["value"] += " " + token
        elif label == "O":
            if current_entity["type"]:
                extracted_info.append(current_entity)
            current_entity = {"type": None, "value": ""}
    if current_entity["type"]:
        extracted_info.append(current_entity)
    # Group entities by employee
    employees = []
    current_employee = {}
    for entity in extracted_info:
        if entity["type"] == "NAME":
            if current_employee:
                employees.append(current_employee)
            current_employee = {"name": entity["value"]}
        else:
            current_employee[entity["type"].lower()] = entity["value"]
    if current_employee:
        employees.append(current_employee)
    return json.dumps(employees, indent=2)

def test_model(html_input, model, tokenizer):
    result = extract_content(html_input, model, tokenizer)
    return result
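
# --- Gradio UI ---
# Two modes: paste custom HTML, or generate a random directory page and run
# extraction on it.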

def gradio_interface(html_input, test_type):
    global model, tokenizer
    if test_type == "Custom Input":
        result = test_model(html_input, model, tokenizer)
        return html_input, result
    elif test_type == "Generate Random HTML":
        random_html, _ = generate_html_content()
        result = test_model(random_html, model, tokenizer)
        return random_html, result

# Check if the model is already trained and saved
if os.path.exists('model.pth') and os.path.exists('tokenizer'):
    print("Loading pre-trained model...")
    model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=9)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    tokenizer = BertTokenizer.from_pretrained('tokenizer')
else:
    print("Training new model...")
    model, tokenizer = train_model()
    # Save the model and tokenizer
    torch.save(model.state_dict(), 'model.pth')
    tokenizer.save_pretrained('tokenizer')
| print("Launching Gradio interface...") | |
| iface = gr.Interface( | |
| fn=gradio_interface, | |
| inputs=[ | |
| gr.Textbox(lines=10, label="Input HTML"), | |
| gr.Radio(["Custom Input", "Generate Random HTML"], label="Test Type", value="Custom Input") | |
| ], | |
| outputs=[ | |
| gr.Textbox(lines=10, label="HTML Content"), | |
| gr.Textbox(label="Extracted Information (JSON)") | |
| ], | |
| title="HTML Content Extractor", | |
| description="Enter HTML content or generate random HTML to test the model. The model will extract employee information and return it in JSON format." | |
| ) | |
| iface.launch() |