import os
import time

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, SwinForImageClassification

# -----------------------------
# 1. Face shape descriptions
# -----------------------------
# Indonesian sentence fragments keyed by face-shape label. Each one already
# ends with a full stop because it is interpolated into the explanation text.
face_shape_descriptions = {
    "Heart": "dengan dahi lebar dan dagu yang runcing.",
    "Oblong": "yang lebih panjang dari lebar dengan garis pipi lurus.",
    "Oval": "dengan proporsi seimbang dan dagu sedikit melengkung.",
    "Round": "dengan garis rahang melengkung dan pipi penuh.",
    "Square": "dengan rahang tegas dan dahi lebar.",
}

# -----------------------------
# 2. Glasses images (frames)
# -----------------------------
# Frame name -> image path. The paths are relative, so they resolve against
# the current working directory of the process.
glasses_images = {
    "Oval": "glasses/oval.jpg",
    "Round": "glasses/round.jpg",
    "Square": "glasses/square.jpg",
    "Octagon": "glasses/octagon.jpg",
    "Cat Eye": "glasses/cat eye.jpg",
    "Pilot (Aviator)": "glasses/aviator.jpg",
}

# Guarantee that every catalogue entry points at a readable file: any missing
# image is replaced by a plain gray 200x100 placeholder. exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs("glasses", exist_ok=True)
for placeholder_path in glasses_images.values():
    if not os.path.exists(placeholder_path):
        Image.new('RGB', (200, 100), color='gray').save(placeholder_path)

# -----------------------------
# 3. Label mappings
# -----------------------------
# Class index <-> face-shape label, in both directions.
id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
label2id = {label: index for index, label in id2label.items()}
# -----------------------------
# 4. Load Model
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# ignore_mismatched_sizes=True keeps from_pretrained from failing when a
# checkpoint tensor does not have the size implied by our label mappings.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load trained weights (optional)
trained_weights_path = 'LR-0001-adamW-32-64swin.pth'
if os.path.exists(trained_weights_path):
    # NOTE(review): torch.load unpickles the file; only ever point this at a
    # trusted, locally produced checkpoint.
    state_dict = torch.load(trained_weights_path, map_location=device)
    # strict=False tolerates key mismatches, so inspect the result instead of
    # assuming every tensor was actually restored.
    load_result = model.load_state_dict(state_dict, strict=False)
    print("✅ Trained weights loaded")
    if load_result.missing_keys:
        print(f"⚠️ Warning: {len(load_result.missing_keys)} model keys were not found in the trained weights")
    if load_result.unexpected_keys:
        print(f"⚠️ Warning: {len(load_result.unexpected_keys)} keys in the trained weights were not used by the model")
else:
    print("⚠️ Warning: 'LR-0001-adamW-32-64swin.pth' not found, using base pretrained weights")

model.to(device)
model.eval()

# -----------------------------
# 5. Mediapipe
# -----------------------------
# A single detector instance created at import time and reused by every call.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)

# -----------------------------
# 6. Rule-based glasses recommendation
# -----------------------------
# Lower-cased face shape -> recommended frame names, in display order.
# Tuples keep the table immutable; callers always receive a fresh list.
_FRAMES_BY_FACE_SHAPE = {
    "square": ("Oval", "Round"),
    "round": ("Square", "Octagon", "Cat Eye"),
    "oval": ("Pilot (Aviator)", "Cat Eye", "Round"),
    "heart": ("Oval", "Round", "Cat Eye", "Pilot (Aviator)"),
    "oblong": ("Square", "Pilot (Aviator)", "Cat Eye"),
}


def recommend_glasses_tree(face_shape):
    """Return the frame names recommended for a face shape.

    Args:
        face_shape: Face-shape label such as "Square" or "oval". Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        A new list of frame names, or an empty list when the shape is not one
        of the five known shapes or is not a string at all.
    """
    if not isinstance(face_shape, str):
        return []
    return list(_FRAMES_BY_FACE_SHAPE.get(face_shape.strip().lower(), ()))
# -----------------------------
# 7. Preprocess image
# -----------------------------
def preprocess_image(image):
    """Crop the first detected face and convert it to model input.

    Args:
        image: A PIL image or an array-like of uint8-compatible pixels in RGB
            channel order. ``None`` is accepted and treated as "no face".

    Returns:
        The ``pixel_values`` entry produced by ``feature_extractor`` with its
        leading dimension squeezed away, or ``None`` when no image was given,
        no face was detected, or the detected box is empty after clamping.
    """
    # Nothing uploaded: report it the same way as "no face found" instead of
    # letting np.array(None, dtype=uint8) raise.
    if image is None:
        return None

    # Palette, grayscale or alpha PIL images would otherwise yield arrays the
    # detector cannot consume.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    rgb = np.array(image, dtype=np.uint8)
    if rgb.ndim == 2:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
    elif rgb.ndim == 3 and rgb.shape[2] == 4:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)

    # The detector is fed RGB and the extractor is fed RGB, so the array stays
    # in RGB throughout; no BGR round trip is needed.
    results = mp_face_detection.process(rgb)
    if not results.detections:
        return None

    # Only the first detection is used.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]

    # The box is relative to the image size; scale it to pixels and clamp it
    # to the image bounds.
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    if x2 <= x1 or y2 <= y1:
        return None

    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)


# -----------------------------
# 8. Prediction function
# -----------------------------
def predict(image):
    """Classify the face shape in an image and recommend glasses frames.

    Args:
        image: The uploaded image as delivered by the Gradio image input, or
            ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(label, explanation, gallery_items, elapsed)``:
        the predicted label (or "Unknown" / "Error"), an Indonesian
        explanation, a list of ``(PIL image, frame name)`` pairs, and the
        elapsed wall-clock time formatted as ``"<ms> ms"``. Exceptions are
        not propagated; they are reported through the "Error" result.
    """
    start = time.perf_counter()

    def elapsed_text():
        # Formatted at the moment of return so it covers all work done so far.
        elapsed_ms = (time.perf_counter() - start) * 1000
        return f"{elapsed_ms:.2f} ms"

    try:
        inputs = preprocess_image(image)
        if inputs is None:
            return "Unknown", "Wajah tidak terdeteksi.", [], elapsed_text()

        # Restore the batch dimension removed by preprocess_image.
        inputs = inputs.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)

        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = id2label[pred_idx]
        pred_prob = probs[0][pred_idx].item() * 100

        frame_recommendations = recommend_glasses_tree(pred_label)
        description = face_shape_descriptions.get(pred_label, "tidak dikenali")

        gallery_items = []
        for frame in frame_recommendations:
            frame_image_path = glasses_images.get(frame)
            if not frame_image_path or not os.path.exists(frame_image_path):
                continue
            try:
                # Image.open is lazy and keeps the file open; copy the pixels
                # and let the context manager release the handle.
                with Image.open(frame_image_path) as opened:
                    frame_image = opened.copy()
            except Exception as e:
                # Best effort: a broken catalogue image must not fail the
                # whole prediction.
                print(f"Error loading image for {frame}: {e}")
                continue
            gallery_items.append((frame_image, frame))

        if frame_recommendations:
            recommended_frames_text = ', '.join(frame_recommendations)
            explanation = (
                f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Kamu memiliki bentuk wajah {description} "
                f"Rekomendasi kacamata: {recommended_frames_text}."
            )
        else:
            explanation = (
                f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Tidak ada rekomendasi frame."
            )

        return pred_label, explanation, gallery_items, elapsed_text()

    except Exception as e:
        # Top-level UI boundary: log for the operator, show a message to the
        # user, and never raise into Gradio.
        print(f"Prediction failed: {e!r}")
        return "Error", f"Terjadi kesalahan: {str(e)}", [], elapsed_text()


# -----------------------------
# 9. Gradio UI
# -----------------------------
# NOTE(review): the original indentation was lost, so the nesting below (two
# columns inside one row, handlers and footer inside the Blocks context) is
# reconstructed from the statement order - confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Bentuk Kacamata")
    gr.Markdown("Pastikan foto yang diunggah dapat terlihat jelas bagian wajah. Pastikan hanya menampilkan satu orang atau wajah untuk satu proses deteksi")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Foto Wajah")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(
                label="Rekomendasi Kacamata",
                columns=3,
                show_label=False,
            )
            time_output = gr.Textbox(label="Inference Time (ms)", interactive=False)

    confirm_button.click(
        predict,
        inputs=image_input,
        outputs=[detected_shape, explanation_output, recommendation_gallery, time_output],
    )

    # Reset the input and every output; one value per listed output component.
    restart_button.click(
        lambda: (None, "", "", [], ""),
        inputs=None,
        outputs=[image_input, detected_shape, explanation_output, recommendation_gallery, time_output],
    )

    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

if __name__ == "__main__":
    iface.launch()