Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -146,6 +146,7 @@ def load_model(path):
|
|
| 146 |
|
| 147 |
# --------- Gradio Functions --------- #
|
| 148 |
def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
|
|
| 149 |
global font_path, ocr_model
|
| 150 |
|
| 151 |
# Save the uploaded font file
|
|
@@ -159,41 +160,54 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
| 159 |
dataset = OCRDataset(font_path)
|
| 160 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
| 161 |
|
| 162 |
-
# Visualize one sample
|
| 163 |
img, label = dataset[0]
|
| 164 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
| 165 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
| 166 |
plt.show()
|
| 167 |
|
| 168 |
-
# Initialize model
|
| 169 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
| 170 |
-
criterion = nn.CTCLoss(blank=
|
| 171 |
-
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
|
|
|
| 172 |
|
| 173 |
# Training loop
|
| 174 |
for epoch in range(epochs):
|
|
|
|
|
|
|
|
|
|
| 175 |
for img, targets, target_lengths in dataloader:
|
| 176 |
img = img.to(device)
|
| 177 |
targets = targets.to(device)
|
| 178 |
target_lengths = target_lengths.to(device)
|
| 179 |
|
| 180 |
-
output = model(img)
|
| 181 |
-
batch_size = img.size(0)
|
| 182 |
seq_len = output.size(1)
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
loss = criterion(output.log_softmax(2).transpose(0, 1), targets, input_lengths, target_lengths)
|
| 186 |
optimizer.zero_grad()
|
| 187 |
loss.backward()
|
| 188 |
optimizer.step()
|
| 189 |
|
| 190 |
-
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
save_model(model, model_name)
|
| 195 |
ocr_model = model
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
|
|
|
|
| 146 |
|
| 147 |
# --------- Gradio Functions --------- #
|
| 148 |
def train_model(font_file, epochs=100, learning_rate=0.001):
|
| 149 |
+
import time
|
| 150 |
global font_path, ocr_model
|
| 151 |
|
| 152 |
# Save the uploaded font file
|
|
|
|
| 160 |
dataset = OCRDataset(font_path)
|
| 161 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
| 162 |
|
| 163 |
+
# Visualize one sample
|
| 164 |
img, label = dataset[0]
|
| 165 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
| 166 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
| 167 |
plt.show()
|
| 168 |
|
| 169 |
+
# Initialize model, loss, optimizer, scheduler
|
| 170 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
| 171 |
+
criterion = nn.CTCLoss(blank=BLANK_IDX)
|
| 172 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
| 173 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
| 174 |
|
| 175 |
# Training loop
|
| 176 |
for epoch in range(epochs):
|
| 177 |
+
model.train()
|
| 178 |
+
running_loss = 0.0
|
| 179 |
+
|
| 180 |
for img, targets, target_lengths in dataloader:
|
| 181 |
img = img.to(device)
|
| 182 |
targets = targets.to(device)
|
| 183 |
target_lengths = target_lengths.to(device)
|
| 184 |
|
| 185 |
+
output = model(img) # [B, T, C]
|
|
|
|
| 186 |
seq_len = output.size(1)
|
| 187 |
+
batch_size = img.size(0)
|
| 188 |
+
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
|
| 189 |
+
|
| 190 |
+
log_probs = output.log_softmax(2).transpose(0, 1) # [T, B, C]
|
| 191 |
+
loss = criterion(log_probs, targets, input_lengths, target_lengths)
|
| 192 |
|
|
|
|
| 193 |
optimizer.zero_grad()
|
| 194 |
loss.backward()
|
| 195 |
optimizer.step()
|
| 196 |
|
| 197 |
+
running_loss += loss.item()
|
| 198 |
|
| 199 |
+
avg_loss = running_loss / len(dataloader)
|
| 200 |
+
scheduler.step(avg_loss)
|
| 201 |
+
print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
|
| 202 |
+
|
| 203 |
+
# Save the trained model
|
| 204 |
+
timestamp = time.strftime("%Y%m%d%H%M%S")
|
| 205 |
+
model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
|
| 206 |
save_model(model, model_name)
|
| 207 |
ocr_model = model
|
| 208 |
+
|
| 209 |
+
return f"✅ Training complete! Model saved as '{model_name}'"
|
| 210 |
+
|
| 211 |
|
| 212 |
|
| 213 |
|