Spaces:
Build error
Build error
| import numpy as np | |
| import cv2 | |
| import base64 | |
| import re | |
| import io | |
| from PIL import Image, ImageColor | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from tensorflow.keras.models import load_model | |
| import matplotlib.pyplot as plt | |
# Fetch the trained Keras segmentation model from the Hugging Face Hub at
# import time. Any failure (network, missing file, incompatible Keras version)
# is printed and leaves ``model`` as None, so predict() can answer with a
# readable error instead of the whole app failing to start.
model = None
try:
    model_path = hf_hub_download(
        repo_id="SalmanAboAraj/dental",
        filename="open_source_augmentation_1.keras",
    )
    model = load_model(model_path)
except Exception as e:
    print(f"Error loading model: {e}")
def get_sparse_predictions(predictions):
    """Collapse per-class scores into one class index per element.

    Args:
        predictions: Array-like whose trailing axis holds one score per class.

    Returns:
        Integer array with the trailing axis removed, holding the index of the
        highest-scoring class at each position.
    """
    class_axis = -1  # scores for the different classes live on the last axis
    return np.argmax(predictions, axis=class_axis)
def get_annotated_images(sparse_predictions, labels_colors, labels_numbers):
    """Render a class-index map as an RGB color image.

    Args:
        sparse_predictions: Integer array of class indices (any shape).
        labels_colors: Mapping of label name -> (r, g, b) color.
        labels_numbers: Mapping of label name -> class index.

    Returns:
        ``uint8`` array of shape ``sparse_predictions.shape + (3,)``. A pixel
        whose index belongs to a label present in ``labels_colors`` gets that
        label's color; every other pixel stays (0, 0, 0).
    """
    canvas = np.zeros(sparse_predictions.shape + (3,), dtype=np.uint8)
    # Walk labels in labels_numbers order, skipping those without a color.
    for name in (n for n in labels_numbers if n in labels_colors):
        paint = np.array(labels_colors[name], dtype=np.uint8)
        canvas[sparse_predictions == labels_numbers[name]] = paint
    return canvas
def load_labels_colors(labels_path):
    """Parse a label-map text file into an ordered ``{label_name: (r, g, b)}`` dict.

    Two line formats are recognised, matched at the start of each line:

    * ``name:R,G,B::`` -- decimal components followed by ``::``
      (e.g. ``caries:255,0,0::``);
    * ``name:#RRGGBB`` -- a six-digit hex color (e.g. ``caries:#ff0000``).

    Label names are stripped of surrounding whitespace and lower-cased.
    Blank lines and lines in any other format are skipped. The result always
    starts with ``'background': (0, 0, 0)``; a ``background`` line in the file
    overrides that color but keeps it in first position.

    NOTE: insertion order matters -- ``get_labels_numbers`` turns it into class
    indices, so a skipped line shifts the index of every later label.

    Args:
        labels_path: Path of the label-map file (read as UTF-8).

    Returns:
        The label -> color dict, or ``None`` if the file is missing,
        unreadable, or holds an invalid color. The reason is printed.
    """
    labels_colors = {'background': (0, 0, 0)}
    rgb_pattern = re.compile(r'(?P<label_name>[\w\s\-]+):(?P<red>\d+),(?P<green>\d+),(?P<blue>\d+)::')
    hex_pattern = re.compile(r'(?P<label_name>[\w\s\-]+):(?P<color>#([\d\w]{6}))')
    try:
        with open(labels_path, 'r', encoding='utf-8') as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line:
                    continue
                match = rgb_pattern.match(line)
                if match:
                    color = (int(match.group('red')), int(match.group('green')), int(match.group('blue')))
                else:
                    match = hex_pattern.match(line)
                    if not match:
                        continue
                    # [\d\w] also admits non-hex characters (e.g. '#zzzzzz');
                    # bytes.fromhex raises ValueError for those, so the whole
                    # file is rejected below, not silently mis-parsed.
                    color = tuple(bytes.fromhex(match.group('color')[1:]))
                if any(component > 255 for component in color):
                    # Would otherwise overflow/wrap when cast to uint8 later.
                    raise ValueError(f"color component out of range 0-255 in line {line!r}")
                # Strip so "Tooth :..." and "tooth:..." map to the same key.
                labels_colors[match.group('label_name').strip().lower()] = color
    except FileNotFoundError:
        print(f"Error: {labels_path} not found.")
        return None
    except (OSError, ValueError) as e:  # UnicodeDecodeError is a ValueError
        print(f"Error reading {labels_path}: {e}")
        return None
    return labels_colors
def get_labels_numbers(labels_colors):
    """Assign consecutive class indices to labels in their insertion order.

    Args:
        labels_colors: Ordered mapping whose keys are label names.

    Returns:
        Dict mapping each label name to its 0-based position among the keys.
    """
    label_names = labels_colors.keys()
    return dict(zip(label_names, range(len(labels_colors))))
def _decode_base64_image(image_base64):
    """Decode base64 image text (optionally a ``data:`` URL) into an RGB uint8 array.

    Raises whatever ``base64``/PIL raise on malformed input; ``predict`` turns
    that into an error string.
    """
    if image_base64.startswith("data:") and "," in image_base64:
        # Clients often send "data:image/png;base64,<payload>". b64decode
        # silently drops ':' ';' ',' but keeps the header's letters, which
        # corrupts the payload, so strip the header explicitly.
        image_base64 = image_base64.split(",", 1)[1]
    image_data = base64.b64decode(image_base64)
    with Image.open(io.BytesIO(image_data)) as image:
        return np.array(image.convert("RGB"))


def _blend_overlay(image, overlay, alpha=0.5):
    """Alpha-blend RGB ``overlay`` onto RGB ``image`` (same shape) with uniform opacity.

    Every overlay pixel, including black background, is composited with the
    same ``alpha``: ``out = (1 - alpha) * image + alpha * overlay``.
    """
    blended = image.astype(np.float32) * (1.0 - alpha) + overlay.astype(np.float32) * alpha
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def predict(image_base64):
    """Segment a base64-encoded image and return the colored overlay as base64 PNG.

    Args:
        image_base64: Base64 text of an image file PIL can open. A leading
            ``data:<mime>;base64,`` header is accepted and stripped.

    Returns:
        Base64 text of a PNG the same size as the input image. It shows the
        predicted class colors (from ``labelmap.txt``) blended 50/50 over the
        image. If something fails, a string starting with ``"Error"`` is
        returned instead.
    """
    if model is None:
        return "Error: Model failed to load. Please check the model file."
    labels_colors = load_labels_colors('labelmap.txt')
    if labels_colors is None:
        return "Error: labelmap.txt not found or could not be read."
    labels_numbers = get_labels_numbers(labels_colors)
    try:
        image_np = _decode_base64_image(image_base64)
    except Exception as e:
        return f"Error processing image: {e}"
    orig_height, orig_width = image_np.shape[:2]
    try:
        # NOTE(review): raw 0-255 pixels at 256x256 are fed to the model as in
        # the original code -- confirm this matches the training preprocessing.
        image_resized = cv2.resize(image_np, (256, 256))
        batch = np.expand_dims(image_resized, axis=0)
        mask_pred = model.predict(batch, verbose=0)[0]
        sparse_predictions = get_sparse_predictions(mask_pred)
        annotated_predictions = get_annotated_images(sparse_predictions, labels_colors, labels_numbers)
        # Nearest-neighbour keeps the mask's colors exact (no blended edges).
        mask_resized = cv2.resize(annotated_predictions, (orig_width, orig_height), interpolation=cv2.INTER_NEAREST)
        # Blend in NumPy rather than via pyplot: the output keeps the input's
        # resolution, and no global figure state is shared between requests.
        blended_image = _blend_overlay(image_np, mask_resized, alpha=0.5)
        ok, buffer = cv2.imencode(".png", cv2.cvtColor(blended_image, cv2.COLOR_RGB2BGR))
        if not ok:
            return "Error during prediction: PNG encoding failed"
        return base64.b64encode(buffer).decode("utf-8")
    except Exception as e:
        return f"Error during prediction: {e}"
# Text-in / text-out Gradio app: the client sends a base64-encoded image and
# receives either base64-encoded PNG text or a message starting with "Error".
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Dental Image Segmentation",
    description=(
        "Paste a base64-encoded image and get back a base64-encoded PNG of the "
        "image with the predicted segmentation mask overlaid."
    ),
)
# Launched at import time, as before, so `python app.py` starts the server.
iface.launch()