faheem66 committed on
Commit
11402c8
·
1 Parent(s): d2d9304

added the bert model with synthetic data for initial training and testing

Browse files
Files changed (2) hide show
  1. app.py +223 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import BertTokenizer, BertForSequenceClassification, AdamW
4
+ from sklearn.model_selection import train_test_split
5
+ import gradio as gr
6
+ import random
7
+ from faker import Faker
8
+ import html
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
# Constants
MAX_LENGTH = 512        # BERT's hard maximum input length in tokens
BATCH_SIZE = 16         # samples per training/validation batch
EPOCHS = 5              # full passes over the synthetic dataset
LEARNING_RATE = 2e-5    # conventional BERT fine-tuning learning rate

# Shared Faker instance used by the synthetic-data generators below.
fake = Faker()
19
+
20
+
21
def generate_employee():
    """Create one synthetic employee record.

    Returns a (name, job, extension, email) tuple built from Faker data;
    the email address is derived from the name by lowercasing it and
    replacing spaces with dots.
    """
    full_name = fake.name()
    title = fake.job()
    extension = f"ext. {random.randint(1000, 9999)}"
    address = "{}@example.com".format(full_name.lower().replace(' ', '.'))
    return full_name, title, extension, address
27
+
28
+
29
def generate_html_content(num_employees=9):
    """Render a synthetic employee-directory page as an HTML string.

    Entries are grouped three per column div; each entry shows the
    HTML-escaped name, job title, phone extension and a mailto link.
    """
    staff = [generate_employee() for _ in range(num_employees)]

    parts = [f"""
    <html>
    <head>
    <title>Employee Directory</title>
    </head>
    <body>
    <div class="row ts-three-column-row standard-row">
    """]

    for idx, (name, job, ext, email) in enumerate(staff):
        # Open a fresh column before every third entry.
        if idx % 3 == 0:
            parts.append('<div class="column ts-three-column">')

        parts.append(f"""
    <div class="block">
    <div class="text-block" style="text-align: center;">
    <p>
    <strong>{html.escape(name)}</strong><br>
    <span style="font-size: 16px">{html.escape(job)}</span><br>
    <span style="font-size: 16px">{html.escape(ext)}</span><br>
    <a href="mailto:{html.escape(email)}">Send Email</a>
    </p>
    </div>
    </div>
    """)

        # Close the column after its third entry, or after the final one.
        if (idx + 1) % 3 == 0 or idx == len(staff) - 1:
            parts.append('</div>')

    parts.append("""
    </div>
    </body>
    </html>
    """)

    return ''.join(parts)
68
+
69
+
70
def parse_employees(html_content):
    """Extract labelled employee records back out of generated HTML.

    Internal helper for generate_dataset. Scans line-by-line for the markup
    emitted by generate_html_content and returns one "name\\njob\\next\\nemail"
    string per employee. Values are HTML-unescaped so labels match the
    original Faker values (the renderer applies html.escape; the original
    parser left entities such as &#x27; in the targets).
    """
    employees = []
    name = job = ext = None
    for line in html_content.split('\n'):
        if '<strong>' in line:
            name = html.unescape(line.split('<strong>')[1].split('</strong>')[0])
        elif '<span style="font-size: 16px">' in line:
            value = html.unescape(line.split('<span style="font-size: 16px">')[1].split('</span>')[0])
            # The extension span is distinguished from the job-title span by
            # its literal "ext." prefix; check the span value itself rather
            # than the whole raw line so a job title containing "ext." cannot
            # be misrouted.
            if value.startswith('ext.'):
                ext = value
            else:
                job = value
        elif '<a href="mailto:' in line:
            email = html.unescape(line.split('<a href="mailto:')[1].split('">')[0])
            # The mailto link is the last field per entry, so the record is
            # complete once it is seen.
            employees.append(f"{name}\n{job}\n{ext}\n{email}")
    return employees


def generate_dataset(num_samples=1000):
    """Build (html, extracted_text) training pairs from synthetic pages.

    Each pair is the rendered directory page and the blank-line-separated
    employee records parsed back out of it.
    """
    dataset = []
    for _ in range(num_samples):
        html_content = generate_html_content()
        dataset.append((html_content, '\n\n'.join(parse_employees(html_content))))
    return dataset
90
+
91
+
92
class HTMLDataset(Dataset):
    """Torch dataset over (html, extracted_text) pairs.

    Each item is the BERT-tokenized HTML plus a single float regression
    label — the number of employee records in the target text — suitable
    for BertForSequenceClassification with num_labels=1.

    NOTE(review): the original passed the extracted *string* straight to
    torch.tensor(..., dtype=torch.float), which raises ValueError on the
    first batch. A scalar count is the minimal runnable placeholder; true
    text extraction would need a seq2seq model, not a scalar label.
    """

    def __init__(self, data, tokenizer, max_length):
        # data: list of (html_string, extracted_string) pairs
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        html_text, extracted = self.data[idx]
        encoding = self.tokenizer.encode_plus(
            html_text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Records are separated by blank lines; an empty target counts as 0.
        num_records = len(extracted.split('\n\n')) if extracted else 0

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(float(num_records), dtype=torch.float),
        }
119
+
120
+
121
def train_model(progress=gr.Progress()):
    """Fine-tune bert-base-uncased on the synthetic directory dataset.

    Generates 1000 synthetic pages, holds out 20% for validation, trains
    for EPOCHS epochs with AdamW, and returns (model, tokenizer). Epoch
    progress is surfaced through the Gradio progress tracker; per-epoch
    losses are logged to stdout.
    """
    # Generate synthetic dataset and split off a validation set.
    dataset = generate_dataset(num_samples=1000)
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

    # Initialize tokenizer and model; num_labels=1 gives a regression head.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)

    # Prepare datasets and dataloaders.
    train_dataset = HTMLDataset(train_data, tokenizer, MAX_LENGTH)
    val_dataset = HTMLDataset(val_data, tokenizer, MAX_LENGTH)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # torch.optim.AdamW replaces transformers.AdamW, which is deprecated
    # and removed in recent transformers releases.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in progress.tqdm(range(EPOCHS), desc="Training Progress"):
        model.train()
        train_loss = 0
        for batch in train_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Passing labels makes the model compute the loss internally
            # (MSE for a single-label regression head).
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            train_loss += loss.item()

            optimizer.zero_grad()  # clear grads before accumulating new ones
            loss.backward()
            optimizer.step()

        # Validation pass, no gradient tracking.
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                val_loss += outputs.loss.item()

        avg_train_loss = train_loss / len(train_dataloader)
        avg_val_loss = val_loss / len(val_dataloader)

        # gr.Progress() expects a float progress value, not a message
        # string, so the epoch summary goes to stdout instead.
        print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return model, tokenizer
177
+
178
+
179
def extract_content(html, model, tokenizer):
    """Run the trained classifier over a single HTML document.

    Tokenizes the input to MAX_LENGTH, feeds it through the model on the
    model's own device, and reports the sigmoid of the single logit as a
    confidence score. Placeholder: no actual field extraction happens yet.
    """
    model.eval()
    features = tokenizer.encode_plus(
        html,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    token_ids = features['input_ids'].to(model.device)
    mask = features['attention_mask'].to(model.device)

    with torch.no_grad():
        result = model(token_ids, attention_mask=mask)
        confidence = result.logits.sigmoid().cpu().numpy()

    # Placeholder output until real span extraction is implemented.
    return f"Extracted content (confidence: {confidence[0][0]:.2f})"
203
+
204
+
205
def gradio_interface(html_input):
    """Gradio callback: run extraction on the pasted HTML using the
    module-global trained model and tokenizer."""
    global trained_model, trained_tokenizer
    return extract_content(html_input, trained_model, trained_tokenizer)
209
+
210
+
211
+ print("Starting training process...")
212
+ trained_model, trained_tokenizer = train_model()
213
+ print("Training completed. Launching Gradio interface...")
214
+
215
+ iface = gr.Interface(
216
+ fn=gradio_interface,
217
+ inputs=gr.Textbox(lines=10, label="Input HTML"),
218
+ outputs=gr.Textbox(label="Extracted Content"),
219
+ title="HTML Content Extractor",
220
+ description="Enter HTML content to extract information."
221
+ )
222
+
223
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
+ transformers
+ scikit-learn
+ faker
+ gradio
+ tqdm
+ numpy