testna / app.py
eoeooe's picture
Update app.py
c3c1d28 verified
import os
import string
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms
# ---------- Config ----------
CHARS = string.digits + string.ascii_uppercase
NUM_CLASSES = len(CHARS)
CAPTCHA_LEN = 6
IMG_W, IMG_H = 200, 80 # ต้องตรงกับ Colab ที่เทรน
DEVICE = torch.device("cpu")
CHAR2IDX = {c: i for i, c in enumerate(CHARS)}
IDX2CHAR = {i: c for c, i in CHAR2IDX.items()}
# ---------- Model ----------
class CaptchaNet(nn.Module):
def __init__(self):
super().__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
with torch.no_grad():
dummy = torch.zeros(1, 1, IMG_H, IMG_W)
d = self.cnn(dummy).flatten(1).shape[1]
print(f"Flatten size: {d}")
self.fc = nn.Sequential(
nn.Linear(d, 512), nn.ReLU(), nn.Dropout(0.4),
nn.Linear(512, CAPTCHA_LEN * NUM_CLASSES)
)
def forward(self, x):
x = self.cnn(x).flatten(1)
x = self.fc(x)
return x.view(-1, CAPTCHA_LEN, NUM_CLASSES)
# โหลดโมเดล
model = CaptchaNet().to(DEVICE)
if os.path.exists("model.pt"):
model.load_state_dict(torch.load("model.pt", map_location=DEVICE, weights_only=True))
print("✅ โหลดโมเดลสำเร็จ")
else:
print("⚠️ ไม่พบไฟล์ model.pt")
model.eval()
# Transform
tf = transforms.Compose([
transforms.Grayscale(),
transforms.Resize((IMG_H, IMG_W)),
transforms.ToTensor(),
])
def predict_captcha(image: Image.Image):
if image is None:
return "กรุณาอัปโหลดรูปภาพ CAPTCHA 6 หลัก"
x = tf(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
out = model(x).argmax(-1)[0].cpu().tolist()
result = "".join(IDX2CHAR.get(i, "?") for i in out)
return result
# Gradio UI (เวอร์ชันใหม่)
demo = gr.Interface(
fn=predict_captcha,
inputs=gr.Image(type="pil", label="📤 อัปโหลดรูป CAPTCHA 6 หลัก (0-9 และ A-Z)"),
outputs=gr.Textbox(label="🔠 ผลลัพธ์ที่โมเดลทำนาย"),
title="🛡️ CAPTCHA Solver 6 หลัก",
description="โมเดลเทรนด้วยภาพสังเคราะห์ 6 ตัวอักษร\nอัปโหลดรูปแล้วกด Submit",
# ไม่ใส่ theme และ allow_flagging ที่นี่
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)