import os import string import torch import torch.nn as nn import gradio as gr from PIL import Image from torchvision import transforms # ---------- Config ---------- CHARS = string.digits + string.ascii_uppercase NUM_CLASSES = len(CHARS) CAPTCHA_LEN = 6 IMG_W, IMG_H = 200, 80 # ต้องตรงกับ Colab ที่เทรน DEVICE = torch.device("cpu") CHAR2IDX = {c: i for i, c in enumerate(CHARS)} IDX2CHAR = {i: c for c, i in CHAR2IDX.items()} # ---------- Model ---------- class CaptchaNet(nn.Module): def __init__(self): super().__init__() self.cnn = nn.Sequential( nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), ) with torch.no_grad(): dummy = torch.zeros(1, 1, IMG_H, IMG_W) d = self.cnn(dummy).flatten(1).shape[1] print(f"Flatten size: {d}") self.fc = nn.Sequential( nn.Linear(d, 512), nn.ReLU(), nn.Dropout(0.4), nn.Linear(512, CAPTCHA_LEN * NUM_CLASSES) ) def forward(self, x): x = self.cnn(x).flatten(1) x = self.fc(x) return x.view(-1, CAPTCHA_LEN, NUM_CLASSES) # โหลดโมเดล model = CaptchaNet().to(DEVICE) if os.path.exists("model.pt"): model.load_state_dict(torch.load("model.pt", map_location=DEVICE, weights_only=True)) print("✅ โหลดโมเดลสำเร็จ") else: print("⚠️ ไม่พบไฟล์ model.pt") model.eval() # Transform tf = transforms.Compose([ transforms.Grayscale(), transforms.Resize((IMG_H, IMG_W)), transforms.ToTensor(), ]) def predict_captcha(image: Image.Image): if image is None: return "กรุณาอัปโหลดรูปภาพ CAPTCHA 6 หลัก" x = tf(image).unsqueeze(0).to(DEVICE) with torch.no_grad(): out = model(x).argmax(-1)[0].cpu().tolist() result = "".join(IDX2CHAR.get(i, "?") for i in out) return result # Gradio UI (เวอร์ชันใหม่) demo = gr.Interface( fn=predict_captcha, inputs=gr.Image(type="pil", label="📤 อัปโหลดรูป CAPTCHA 6 หลัก (0-9 และ A-Z)"), outputs=gr.Textbox(label="🔠 ผลลัพธ์ที่โมเดลทำนาย"), title="🛡️ CAPTCHA Solver 6 หลัก", description="โมเดลเทรนด้วยภาพสังเคราะห์ 6 ตัวอักษร\nอัปโหลดรูปแล้วกด Submit", # ไม่ใส่ theme และ allow_flagging ที่นี่ ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=False, show_error=True )