Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import gradio as gr | |
| # Kiểm tra thiết bị | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| print(f'Sử dụng thiết bị: {device}') | |
| # Định nghĩa kiến trúc mô hình phân loại | |
| class ClassificationModel(nn.Module): | |
| def __init__(self, num_classes=12): | |
| super(ClassificationModel, self).__init__() | |
| self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Sử dụng pretrained weights | |
| num_ftrs = self.model.fc.in_features | |
| self.model.fc = nn.Linear(num_ftrs, num_classes) # Thay đổi lớp cuối cùng | |
| def forward(self, x): | |
| return self.model(x) | |
| # Khởi tạo và tải mô hình phân loại | |
| classification_model = ClassificationModel(num_classes=12) | |
| classification_model.load_state_dict(torch.load('classification_state_dict.pt', map_location=device)) | |
| classification_model.to(device) | |
| classification_model.eval() | |
| # Tải mô hình YOLO | |
| yolo_model = YOLO('best.pt') # Đảm bảo rằng 'best.pt' nằm trong thư mục hiện tại | |
| # Định nghĩa các lớp mụn | |
| class_labels = [ | |
| 'acne_scars', 'blackhead', 'cystic', 'flat_wart', 'folliculitis', | |
| 'keloid', 'milium', 'papular', 'purulent', 'sebo-crystan-conglo', | |
| 'syringoma', 'whitehead' | |
| ] | |
| # Ánh xạ tiếng Anh -> tiếng Việt | |
| class_mapping = { | |
| 'acne_scars': 'Sẹo mụn', | |
| 'blackhead': 'Mụn đầu đen', | |
| 'cystic': 'Mụn nang', | |
| 'flat_wart': 'Mụn sần phẳng', | |
| 'folliculitis': 'Viêm nang lông', | |
| 'keloid': 'Mụn sẹo uốn', | |
| 'milium': 'Mụn mili', | |
| 'papular': 'Mụn nhỏ', | |
| 'purulent': 'Mụn mủ', | |
| 'sebo-crystan-conglo': 'Mụn bã đen kết tủa', | |
| 'syringoma': 'Mụn nang mồ hôi', | |
| 'whitehead': 'Mụn đầu trắng' | |
| } | |
| # Định nghĩa các biến đổi dữ liệu cho mô hình phân loại | |
| transform = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize([0.485, 0.456, 0.406], | |
| [0.229, 0.224, 0.225]) | |
| ]) | |
| def detect_and_classify(image, image_size=640, conf_threshold=0.4, iou_threshold=0.5): | |
| """ | |
| Hàm này nhận vào một ảnh, phát hiện các vùng mụn bằng YOLO, | |
| phân loại từng vùng mụn bằng mô hình ResNet18, và trả về ảnh đã được | |
| annotate cùng với các thông tin liên quan. | |
| """ | |
| # Mở ảnh và chuyển đổi sang RGB | |
| pil_image = Image.open(image).convert("RGB") | |
| # Dự đoán bằng YOLO | |
| results = yolo_model.predict(pil_image, conf=conf_threshold, iou=iou_threshold, imgsz=image_size) | |
| boxes = results[0].boxes | |
| num_boxes = len(boxes) | |
| if num_boxes == 0: | |
| severity = "Tốt" | |
| recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da." | |
| return pil_image, f"Tình trạng mụn: {severity}", recommendation, "Không phát hiện mụn." | |
| # Lấy thông tin bounding boxes | |
| xyxy = boxes.xyxy.detach().cpu().numpy().astype(int) # Toạ độ bounding box | |
| confidences = boxes.conf.detach().cpu().numpy() | |
| class_ids = boxes.cls.detach().cpu().numpy().astype(int) | |
| # Chuẩn bị vẽ | |
| draw = ImageDraw.Draw(pil_image) | |
| try: | |
| font = ImageFont.truetype("arial.ttf", 15) | |
| except IOError: | |
| font = ImageFont.load_default() | |
| # Danh sách để lưu kết quả phân loại | |
| classified_results = [] | |
| for i, (box, cls_id, conf) in enumerate(zip(xyxy, class_ids, confidences), start=1): | |
| x1, y1, x2, y2 = box | |
| class_name_en = class_labels[cls_id] | |
| class_name_vn = class_mapping.get(class_name_en, class_name_en) | |
| # Cắt crop vùng mụn | |
| crop = pil_image.crop((x1, y1, x2, y2)) | |
| img_transformed = transform(crop).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| output = classification_model(img_transformed) | |
| probabilities = torch.softmax(output, dim=1) | |
| top_prob, top_class = probabilities.topk(1, dim=1) | |
| top_prob = top_prob.item() | |
| top_class = top_class.item() | |
| class_name_en = class_labels[top_class] | |
| class_name_vn = class_mapping.get(class_name_en, class_name_en) | |
| # Vẽ bounding box và nhãn | |
| label = f"#{i}: {class_name_en} ({class_name_vn}) ({top_prob:.2f})" | |
| # Sử dụng textbbox thay vì textsize | |
| text_bbox = draw.textbbox((0, 0), label, font=font) | |
| text_w = text_bbox[2] - text_bbox[0] | |
| text_h = text_bbox[3] - text_bbox[1] | |
| draw.rectangle([x1, y1, x2, y2], outline="red", width=2) | |
| draw.rectangle([x1, y1 - text_h, x1 + text_w, y1], fill="red") | |
| draw.text((x1, y1 - text_h), label, fill="white", font=font) | |
| # Thêm kết quả phân loại vào danh sách | |
| classified_results.append((i, class_name_en, class_name_vn)) | |
| # Đánh giá tình trạng da dựa trên số lượng mụn | |
| if num_boxes > 20: | |
| severity = "Nặng" | |
| recommendation = "Bạn nên đến gặp bác sĩ da liễu và sử dụng liệu trình trị mụn chuyên sâu." | |
| elif 10 <= num_boxes <= 20: | |
| severity = "Trung bình" | |
| recommendation = "Duy trì skincare đều đặn với sữa rửa mặt dịu nhẹ và dưỡng ẩm." | |
| else: | |
| severity = "Tốt" | |
| recommendation = "Làn da bạn khá ổn! Tiếp tục duy trì thói quen chăm sóc da." | |
| # Liệt kê loại mụn | |
| acne_types_str = "Danh sách mụn phát hiện:\n" | |
| for idx, cname_en, cname_vn in classified_results: | |
| acne_types_str += f"Mụn #{idx}: {cname_en} ({cname_vn})\n" | |
| return pil_image, f"Tình trạng mụn: {severity}", recommendation, acne_types_str | |
| # Mô tả ứng dụng | |
| description_md = """ | |
| ## Ứng Dụng Nhận Diện và Phân Loại Mụn Bằng YOLO và ResNet18 | |
| 1. **Phát hiện mụn:** Sử dụng mô hình YOLO để phát hiện các vùng mụn trên khuôn mặt. | |
| 2. **Phân loại mụn:** Sử dụng mô hình ResNet18 đã được huấn luyện để phân loại từng vùng mụn thành 12 loại khác nhau bao gồm: Sẹo mụn trứng cá, Mụn đầu đen, Mụn nang, Mụn phẳng, Viêm nang lông, Sẹo mùm, Mụn mili, Mụn sần nhỏ, Mụn mủ, Mụn bã đen kết tủa, Mụn nang mồ hôi, Mụn đầu trắng | |
| 3. **Hiển thị kết quả:** Ảnh sau khi xử lý sẽ hiển thị các bounding boxes, nhãn tiếng Anh và tiếng Việt của loại mụn, cùng với độ chính xác của mỗi phân loại. | |
| 4. **Đánh giá tình trạng da:** Cung cấp đánh giá tổng quát về tình trạng da và khuyến nghị tương ứng dựa trên số lượng mụn được phát hiện. | |
| """ | |
| # Định nghĩa giao diện Gradio | |
| inputs = [ | |
| gr.Image(type="filepath", label="Ảnh Khuôn Mặt"), | |
| gr.Slider(minimum=320, maximum=1280, step=32, value=640, label="Kích thước ảnh (Image Size)"), | |
| gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4, label="Ngưỡng Confidence"), | |
| gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Ngưỡng IOU") | |
| ] | |
| outputs = [ | |
| gr.Image(type="pil", label="Ảnh Sau Khi Xử Lý"), | |
| gr.Textbox(label="Tình Trạng Mụn"), | |
| gr.Textbox(label="Khuyến Nghị"), | |
| gr.Textbox(label="Loại Mụn Phát Hiện") | |
| ] | |
| # Tạo giao diện Gradio | |
| app = gr.Interface( | |
| fn=detect_and_classify, | |
| inputs=inputs, | |
| outputs=outputs, | |
| title="YOLO + ResNet18 Phát Hiện và Phân Loại Mụn", | |
| description=description_md | |
| ) | |
| # Khởi chạy ứng dụng | |
| app.launch(share=True) | |