Spaces:
Paused
Paused
resolving bugs
Browse files
app.py
CHANGED
|
@@ -107,7 +107,7 @@ class HTMLDataset(Dataset):
|
|
| 107 |
}
|
| 108 |
|
| 109 |
|
| 110 |
-
def train_model(
|
| 111 |
# Generate synthetic dataset
|
| 112 |
dataset = generate_dataset(num_samples=1000)
|
| 113 |
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)
|
|
@@ -129,10 +129,10 @@ def train_model(progress=gr.Progress()):
|
|
| 129 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 130 |
model.to(device)
|
| 131 |
|
| 132 |
-
for epoch in
|
| 133 |
model.train()
|
| 134 |
train_loss = 0
|
| 135 |
-
for batch in train_dataloader:
|
| 136 |
input_ids = batch['input_ids'].to(device)
|
| 137 |
attention_mask = batch['attention_mask'].to(device)
|
| 138 |
labels = batch['labels'].to(device)
|
|
@@ -160,7 +160,7 @@ def train_model(progress=gr.Progress()):
|
|
| 160 |
avg_train_loss = train_loss / len(train_dataloader)
|
| 161 |
avg_val_loss = val_loss / len(val_dataloader)
|
| 162 |
|
| 163 |
-
|
| 164 |
|
| 165 |
return model, tokenizer
|
| 166 |
|
|
|
|
| 107 |
}
|
| 108 |
|
| 109 |
|
| 110 |
+
def train_model():
|
| 111 |
# Generate synthetic dataset
|
| 112 |
dataset = generate_dataset(num_samples=1000)
|
| 113 |
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)
|
|
|
|
| 129 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 130 |
model.to(device)
|
| 131 |
|
| 132 |
+
for epoch in range(EPOCHS):
|
| 133 |
model.train()
|
| 134 |
train_loss = 0
|
| 135 |
+
for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
|
| 136 |
input_ids = batch['input_ids'].to(device)
|
| 137 |
attention_mask = batch['attention_mask'].to(device)
|
| 138 |
labels = batch['labels'].to(device)
|
|
|
|
| 160 |
avg_train_loss = train_loss / len(train_dataloader)
|
| 161 |
avg_val_loss = val_loss / len(val_dataloader)
|
| 162 |
|
| 163 |
+
print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
|
| 164 |
|
| 165 |
return model, tokenizer
|
| 166 |
|