# Hugging Face Spaces page header (Space status: Sleeping)
import gradio as gr
import torch
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
from PIL import Image

# Load the segmentation model and its preprocessor once, when the module is imported.
# Both come from the same pretrained checkpoint, so its name is kept in one place.
_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

device = torch.device("cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
model = model.to(device)
# Inference only: switch layers such as dropout to evaluation behavior.
model.eval()

preprocessor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)
# Gradio callback: segment an uploaded image and highlight a single class.
def query_image(img, rule_class_id=1):
    """Return a black/white mask image of the pixels predicted as ``rule_class_id``.

    Args:
        img: The uploaded image as delivered by ``gr.Image()``. Presumably a
            NumPy array of shape (H, W, C), Gradio's default; a PIL image also
            works because only ``np.asarray(img).shape`` is inspected here.
            ``None`` when nothing has been uploaded.
        rule_class_id: Semantic class id to highlight. Defaults to 1, the id
            the original code used for the "rule" class.
            NOTE(review): confirm against ``model.config.id2label`` that id 1 is
            really the intended class for this ADE20K checkpoint.

    Returns:
        A PIL RGB image with the same height and width as ``img``: white where
        the predicted class equals ``rule_class_id``, black elsewhere. Returns
        ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    # Output size (H, W) for post-processing, so the mask lines up with the upload.
    height, width = np.asarray(img).shape[:2]

    inputs = preprocessor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # MaskFormer returns per-query class/mask logits rather than a dense
    # `logits` tensor; the preprocessor combines them into a per-pixel
    # class-id map resized to the requested target size.
    segmentation = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[(height, width)]
    )[0]
    class_map = segmentation.cpu().numpy()

    # 255 where the pixel belongs to the requested class, 0 elsewhere,
    # replicated over three channels to form an RGB image.
    rule_mask = (class_map == rule_class_id).astype(np.uint8) * 255
    return Image.fromarray(np.stack([rule_mask] * 3, axis=-1))
# Build the Gradio UI around `query_image` and start serving it.
interface_text = dict(
    title="Rule Segmentation Demo",
    description="Please upload an image to see rule segmentation",
)

demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    **interface_text,
)

demo.launch()