"""Gradio demo: flooded vs non-flooded image classification with a fine-tuned ResNet18.

The checkpoint is downloaded from the Hugging Face Hub when this module is
imported (``model = load_model()`` runs at module level), so the model is ready
before the Gradio interface is launched.
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import gradio as gr

# ----------- Class labels -----------
# The CLASS ORDER must match the one used at training time (train_dataset.classes).
# With ImageFolder and folders named 'flooded' / 'non-flooded', classes are sorted
# alphabetically, which yields the order below.
class_names = ["flooded", "non-flooded"]


# ----------- Load model from HuggingFace Hub -----------
def load_model():
    """Download the fine-tuned ResNet18 checkpoint and return it in eval mode.

    Returns:
        A torchvision ResNet18 whose final linear layer outputs one logit per
        entry in ``class_names``, with the checkpoint weights loaded on the CPU.
    """
    # Fetch the .pth file that was uploaded to the model repo.
    ckpt_path = hf_hub_download(
        repo_id="dat201204/resnet18-flood-detection-cvfd",
        filename="resnet18_tl_cvfd.pth",
    )

    # Rebuild the same architecture used during training. The head size is tied
    # to class_names so the label list and the logits cannot drift apart.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))

    # NOTE(review): torch.load unpickles the file, and this checkpoint comes from
    # the network. On torch >= 1.13 consider passing weights_only=True (the file
    # is expected to hold only a state_dict) - confirm against the torch version
    # pinned for this app.
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model


model = load_model()

# ----------- Image transforms -----------
# Resize to 224x224 and normalize with the standard ImageNet mean/std used by
# torchvision's pretrained ResNet weights.
tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# ----------- Prediction function -----------
def predict(img):
    """Classify a single image as flooded / non-flooded.

    Args:
        img: A ``PIL.Image.Image`` (as produced by ``gr.Image(type="pil")``), or
            ``None`` when the user submits without providing an image.

    Returns:
        A dict ``{label: probability}`` covering every class in ``class_names``,
        or ``None`` when no image was given (instead of failing inside the
        transform pipeline).
    """
    if img is None:
        return None

    # Grayscale ("L"), palette ("P") or RGBA images would become 1- or 4-channel
    # tensors, which the 3-channel Normalize rejects; the network expects RGB.
    img_tensor = tfms(img.convert("RGB")).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        logits = model(img_tensor)
        probs = torch.softmax(logits, dim=1)[0].tolist()

    # Pair each class label with its softmax probability.
    return {cls: float(p) for cls, p in zip(class_names, probs)}


# ----------- Gradio UI -----------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title="Flood Detection — ResNet18 (Transfer Learning, CVFD)",
    description="Upload ảnh đường / camera để phát hiện Flooded vs Non-flooded.",
)

if __name__ == "__main__":
    demo.launch()