Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -38,6 +38,10 @@ alphabet = Alphabet.build_alphabet(labels)
|
|
| 38 |
|
| 39 |
# Now initialize decoder correctly
|
| 40 |
decoder = BeamSearchDecoderCTC(alphabet)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# --------- Dataset --------- #
|
| 43 |
class OCRDataset(Dataset):
|
|
@@ -155,7 +159,11 @@ def custom_collate_fn(batch):
|
|
| 155 |
|
| 156 |
# --------- Model Save/Load --------- #
|
| 157 |
def list_saved_models():
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
def save_model(model, path):
|
|
@@ -177,14 +185,18 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
| 177 |
import time
|
| 178 |
global font_path, ocr_model
|
| 179 |
|
| 180 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
font_name = os.path.splitext(os.path.basename(font_file.name))[0]
|
| 182 |
-
font_path = f"./{font_name}.ttf"
|
| 183 |
with open(font_file.name, "rb") as uploaded:
|
| 184 |
with open(font_path, "wb") as f:
|
| 185 |
f.write(uploaded.read())
|
| 186 |
|
| 187 |
-
# Curriculum learning:
|
| 188 |
def get_dataset_for_epoch(epoch):
|
| 189 |
if epoch < epochs // 3:
|
| 190 |
label_len = (3, 4)
|
|
@@ -194,29 +206,27 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
| 194 |
label_len = (5, 7)
|
| 195 |
return OCRDataset(font_path, label_length_range=label_len)
|
| 196 |
|
| 197 |
-
# Visualize one sample
|
| 198 |
dataset = get_dataset_for_epoch(0)
|
| 199 |
-
img, label, _ = dataset[0]
|
| 200 |
-
|
| 201 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
| 202 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
| 203 |
plt.show()
|
| 204 |
|
| 205 |
-
#
|
| 206 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
| 207 |
criterion = nn.CTCLoss(blank=BLANK_IDX)
|
| 208 |
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
| 209 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
| 210 |
|
| 211 |
for epoch in range(epochs):
|
| 212 |
-
# Load new dataset for current curriculum stage
|
| 213 |
dataset = get_dataset_for_epoch(epoch)
|
| 214 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
| 215 |
|
| 216 |
model.train()
|
| 217 |
running_loss = 0.0
|
| 218 |
|
| 219 |
-
#
|
| 220 |
if epoch < 5:
|
| 221 |
warmup_lr = learning_rate * 0.2
|
| 222 |
for param_group in optimizer.param_groups:
|
|
@@ -230,12 +240,12 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
| 230 |
targets = targets.to(device)
|
| 231 |
target_lengths = target_lengths.to(device)
|
| 232 |
|
| 233 |
-
output = model(img)
|
| 234 |
seq_len = output.size(1)
|
| 235 |
batch_size = img.size(0)
|
| 236 |
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
|
| 237 |
|
| 238 |
-
log_probs = output.log_softmax(2).transpose(0, 1)
|
| 239 |
loss = criterion(log_probs, targets, input_lengths, target_lengths)
|
| 240 |
|
| 241 |
optimizer.zero_grad()
|
|
@@ -248,13 +258,15 @@ def train_model(font_file, epochs=100, learning_rate=0.001):
|
|
| 248 |
scheduler.step(avg_loss)
|
| 249 |
print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
|
| 250 |
|
| 251 |
-
# Save the model
|
| 252 |
timestamp = time.strftime("%Y%m%d%H%M%S")
|
| 253 |
model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
ocr_model = model
|
|
|
|
| 256 |
|
| 257 |
-
return f"✅ Training complete! Model saved as '{model_name}'"
|
| 258 |
|
| 259 |
|
| 260 |
|
|
@@ -376,11 +388,11 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
| 376 |
global font_path
|
| 377 |
|
| 378 |
try:
|
| 379 |
-
if font_file:
|
| 380 |
-
font_path =
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
if font_path is None or not os.path.exists(font_path):
|
| 385 |
font = ImageFont.load_default()
|
| 386 |
else:
|
|
@@ -391,7 +403,6 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
| 391 |
images = []
|
| 392 |
|
| 393 |
for label in labels:
|
| 394 |
-
# Measure text size and calculate padded image dimensions
|
| 395 |
bbox = font.getbbox(label)
|
| 396 |
text_w = bbox[2] - bbox[0]
|
| 397 |
text_h = bbox[3] - bbox[1]
|
|
@@ -399,12 +410,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
| 399 |
img_w = text_w + pad * 2
|
| 400 |
img_h = text_h + pad * 2
|
| 401 |
|
| 402 |
-
# Create image and draw text
|
| 403 |
img = Image.new("L", (img_w, img_h), color=255)
|
| 404 |
draw = ImageDraw.Draw(img)
|
| 405 |
draw.text((pad, pad), label, font=font, fill=0)
|
| 406 |
|
| 407 |
-
# Save to ./labels/sanitized_label/timestamp.png
|
| 408 |
safe_label = sanitize_filename(label)
|
| 409 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
| 410 |
label_dir = os.path.join("./labels", safe_label)
|
|
@@ -424,6 +433,10 @@ def generate_labels(font_file=None, num_labels: int = 25):
|
|
| 424 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
| 425 |
return [error_img]
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
custom_css = """
|
| 428 |
#label-gallery .gallery-item img {
|
| 429 |
height: 43px; /* 32pt ≈ 43px */
|
|
@@ -444,7 +457,7 @@ custom_css = """
|
|
| 444 |
|
| 445 |
# --------- Updated Gradio UI with new tab --------- #
|
| 446 |
with gr.Blocks(css=custom_css) as demo:
|
| 447 |
-
with gr.Tab("
|
| 448 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
| 449 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|
| 450 |
lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
|
|
@@ -453,8 +466,28 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 453 |
|
| 454 |
train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
|
| 455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
-
|
|
|
|
| 458 |
model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
|
| 459 |
refresh_btn = gr.Button("🔄 Refresh Models")
|
| 460 |
load_model_btn = gr.Button("Load Model") # <-- new button
|
|
@@ -472,23 +505,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 472 |
|
| 473 |
predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
|
| 474 |
|
| 475 |
-
with gr.Tab("3. Generate Labels"):
|
| 476 |
-
font_file_labels = gr.File(label="Optional font for label image", file_types=[".ttf", ".otf"])
|
| 477 |
-
num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
|
| 478 |
-
gen_button = gr.Button("Generate Label Grid")
|
| 479 |
-
|
| 480 |
-
gen_button.click(
|
| 481 |
-
fn=generate_labels,
|
| 482 |
-
inputs=[font_file_labels, num_labels],
|
| 483 |
-
outputs=gr.Gallery(
|
| 484 |
-
label="Generated Labels",
|
| 485 |
-
columns=16, # 16 tiles per row
|
| 486 |
-
object_fit="contain", # Maintain aspect ratio
|
| 487 |
-
height="100%", # Allow full app height
|
| 488 |
-
elem_id="label-gallery" # For CSS targeting
|
| 489 |
-
)
|
| 490 |
|
| 491 |
-
)
|
| 492 |
|
| 493 |
|
| 494 |
|
|
|
|
| 38 |
|
| 39 |
# Now initialize decoder correctly
|
| 40 |
decoder = BeamSearchDecoderCTC(alphabet)
|
| 41 |
+
# Ensure required directories exist at startup
|
| 42 |
+
os.makedirs("./fonts", exist_ok=True)
|
| 43 |
+
os.makedirs("./models", exist_ok=True)
|
| 44 |
+
os.makedirs("./labels", exist_ok=True)
|
| 45 |
|
| 46 |
# --------- Dataset --------- #
|
| 47 |
class OCRDataset(Dataset):
|
|
|
|
| 159 |
|
| 160 |
# --------- Model Save/Load --------- #
|
| 161 |
def list_saved_models():
    """Return the file names of saved model checkpoints in ``./models``.

    Only entries whose name ends with ``.pth`` are returned, as bare file
    names (no directory prefix), sorted so the result is deterministic.

    Returns:
        list[str]: Sorted ``.pth`` file names; an empty list when
        ``./models`` is missing or is not a directory.
    """
    model_dir = "./models"
    try:
        entries = os.listdir(model_dir)
    except (FileNotFoundError, NotADirectoryError):
        # No usable model directory means there is nothing to offer.
        return []
    # os.listdir order is arbitrary; sort for a stable listing.
    return sorted(name for name in entries if name.endswith(".pth"))
|
| 166 |
+
|
| 167 |
|
| 168 |
|
| 169 |
def save_model(model, path):
|
|
|
|
| 185 |
import time
|
| 186 |
global font_path, ocr_model
|
| 187 |
|
| 188 |
+
# Ensure directories exist
|
| 189 |
+
os.makedirs("./fonts", exist_ok=True)
|
| 190 |
+
os.makedirs("./models", exist_ok=True)
|
| 191 |
+
|
| 192 |
+
# Save uploaded font to ./fonts
|
| 193 |
font_name = os.path.splitext(os.path.basename(font_file.name))[0]
|
| 194 |
+
font_path = f"./fonts/{font_name}.ttf"
|
| 195 |
with open(font_file.name, "rb") as uploaded:
|
| 196 |
with open(font_path, "wb") as f:
|
| 197 |
f.write(uploaded.read())
|
| 198 |
|
| 199 |
+
# Curriculum learning: label length grows over time
|
| 200 |
def get_dataset_for_epoch(epoch):
|
| 201 |
if epoch < epochs // 3:
|
| 202 |
label_len = (3, 4)
|
|
|
|
| 206 |
label_len = (5, 7)
|
| 207 |
return OCRDataset(font_path, label_length_range=label_len)
|
| 208 |
|
| 209 |
+
# Visualize one sample
|
| 210 |
dataset = get_dataset_for_epoch(0)
|
| 211 |
+
img, label, _ = dataset[0]
|
|
|
|
| 212 |
print("Label:", ''.join([IDX2CHAR[i.item()] for i in label]))
|
| 213 |
plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray')
|
| 214 |
plt.show()
|
| 215 |
|
| 216 |
+
# Model setup
|
| 217 |
model = OCRModel(num_classes=len(CHAR2IDX)).to(device)
|
| 218 |
criterion = nn.CTCLoss(blank=BLANK_IDX)
|
| 219 |
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
| 220 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
| 221 |
|
| 222 |
for epoch in range(epochs):
|
|
|
|
| 223 |
dataset = get_dataset_for_epoch(epoch)
|
| 224 |
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)
|
| 225 |
|
| 226 |
model.train()
|
| 227 |
running_loss = 0.0
|
| 228 |
|
| 229 |
+
# Warmup learning rate
|
| 230 |
if epoch < 5:
|
| 231 |
warmup_lr = learning_rate * 0.2
|
| 232 |
for param_group in optimizer.param_groups:
|
|
|
|
| 240 |
targets = targets.to(device)
|
| 241 |
target_lengths = target_lengths.to(device)
|
| 242 |
|
| 243 |
+
output = model(img)
|
| 244 |
seq_len = output.size(1)
|
| 245 |
batch_size = img.size(0)
|
| 246 |
input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device)
|
| 247 |
|
| 248 |
+
log_probs = output.log_softmax(2).transpose(0, 1)
|
| 249 |
loss = criterion(log_probs, targets, input_lengths, target_lengths)
|
| 250 |
|
| 251 |
optimizer.zero_grad()
|
|
|
|
| 258 |
scheduler.step(avg_loss)
|
| 259 |
print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}")
|
| 260 |
|
| 261 |
+
# Save the model to ./models
|
| 262 |
timestamp = time.strftime("%Y%m%d%H%M%S")
|
| 263 |
model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth"
|
| 264 |
+
model_path = os.path.join("./models", model_name)
|
| 265 |
+
save_model(model, model_path)
|
| 266 |
+
|
| 267 |
ocr_model = model
|
| 268 |
+
return f"✅ Training complete! Model saved as '{model_path}'"
|
| 269 |
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
|
|
|
|
| 388 |
global font_path
|
| 389 |
|
| 390 |
try:
|
| 391 |
+
if font_file and font_file != "None":
|
| 392 |
+
font_path = os.path.abspath(font_file)
|
| 393 |
+
else:
|
| 394 |
+
font_path = None
|
| 395 |
+
|
| 396 |
if font_path is None or not os.path.exists(font_path):
|
| 397 |
font = ImageFont.load_default()
|
| 398 |
else:
|
|
|
|
| 403 |
images = []
|
| 404 |
|
| 405 |
for label in labels:
|
|
|
|
| 406 |
bbox = font.getbbox(label)
|
| 407 |
text_w = bbox[2] - bbox[0]
|
| 408 |
text_h = bbox[3] - bbox[1]
|
|
|
|
| 410 |
img_w = text_w + pad * 2
|
| 411 |
img_h = text_h + pad * 2
|
| 412 |
|
|
|
|
| 413 |
img = Image.new("L", (img_w, img_h), color=255)
|
| 414 |
draw = ImageDraw.Draw(img)
|
| 415 |
draw.text((pad, pad), label, font=font, fill=0)
|
| 416 |
|
|
|
|
| 417 |
safe_label = sanitize_filename(label)
|
| 418 |
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
| 419 |
label_dir = os.path.join("./labels", safe_label)
|
|
|
|
| 433 |
draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
|
| 434 |
return [error_img]
|
| 435 |
|
| 436 |
+
def list_fonts(font_dir="./fonts"):
    """Return the font choices offered by the label-generation dropdown.

    The list always starts with the literal string ``"None"``, followed by
    every ``.ttf``/``.otf`` file (case-insensitive) found in the current
    working directory (as bare file names) and then in *font_dir* (as
    ``font_dir/<name>`` paths).

    Args:
        font_dir: Directory scanned in addition to the working directory.
            Defaults to ``./fonts``, where ``train_model`` writes uploaded
            fonts.

    Returns:
        list[str]: ``"None"`` followed by the sorted font entries.
    """
    font_exts = (".ttf", ".otf")

    def _fonts_in(directory):
        # A missing or non-directory path simply contributes no fonts.
        try:
            names = os.listdir(directory)
        except (FileNotFoundError, NotADirectoryError):
            return []
        return sorted(n for n in names if n.lower().endswith(font_exts))

    local_fonts = _fonts_in(".")
    # Keep the directory prefix so the selected value still points at the
    # real file when it is later resolved relative to the working directory.
    stored_fonts = [os.path.join(font_dir, n) for n in _fonts_in(font_dir)]
    return ["None"] + local_fonts + stored_fonts
|
| 439 |
+
|
| 440 |
custom_css = """
|
| 441 |
#label-gallery .gallery-item img {
|
| 442 |
height: 43px; /* 32pt ≈ 43px */
|
|
|
|
| 457 |
|
| 458 |
# --------- Updated Gradio UI with new tab --------- #
|
| 459 |
with gr.Blocks(css=custom_css) as demo:
|
| 460 |
+
with gr.Tab("【Train OCR Model】"):
|
| 461 |
font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
|
| 462 |
epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
|
| 463 |
lr_input = gr.Slider(minimum=0.001, maximum=0.1, value=0.05, step=0.001, label="Learning Rate")
|
|
|
|
| 466 |
|
| 467 |
train_button.click(fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status)
|
| 468 |
|
| 469 |
+
with gr.Tab("【Generate Labels】"):
|
| 470 |
+
font_file_labels = gr.Dropdown(
|
| 471 |
+
choices=list_fonts(),
|
| 472 |
+
label="Optional font for label image",
|
| 473 |
+
interactive=True,
|
| 474 |
+
)
|
| 475 |
+
num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
|
| 476 |
+
gen_button = gr.Button("Generate Label Grid")
|
| 477 |
+
|
| 478 |
+
gen_button.click(
|
| 479 |
+
fn=generate_labels,
|
| 480 |
+
inputs=[font_file_labels, num_labels],
|
| 481 |
+
outputs=gr.Gallery(
|
| 482 |
+
label="Generated Labels",
|
| 483 |
+
columns=16, # 16 tiles per row
|
| 484 |
+
object_fit="contain", # Maintain aspect ratio
|
| 485 |
+
height="100%", # Allow full app height
|
| 486 |
+
elem_id="label-gallery" # For CSS targeting
|
| 487 |
+
)
|
| 488 |
|
| 489 |
+
)
|
| 490 |
+
with gr.Tab("【Recognize Text】"):
|
| 491 |
model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
|
| 492 |
refresh_btn = gr.Button("🔄 Refresh Models")
|
| 493 |
load_model_btn = gr.Button("Load Model") # <-- new button
|
|
|
|
| 505 |
|
| 506 |
predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
|
|
|
| 509 |
|
| 510 |
|
| 511 |
|