"""Gradio app: semantic segmentation of dental images with a Keras model.

The app accepts a base64-encoded image as text. It runs the segmentation model
and returns a base64-encoded PNG of the input with the colour-coded
segmentation mask blended on top.
"""
import numpy as np
import cv2
import base64
import re
import io
from PIL import Image, ImageColor
import gradio as gr
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt  # noqa: F401  (no longer used for blending; kept for compatibility)

# Spatial size the image is resized to before inference.
# NOTE(review): assumes the model was trained on 256x256 raw 0-255 RGB input — confirm.
MODEL_INPUT_SIZE = (256, 256)
# Opacity of the segmentation colours when blended over the original image.
MASK_ALPHA = 0.5
# Label map: each line gives a label name and its display colour.
LABELMAP_PATH = 'labelmap.txt'

# "label:R,G,B::" lines (CVAT-style label map).
_LABEL_RGB_PATTERN = re.compile(
    r'(?P<label_name>[\w\s\-]+):(?P<red>\d+),(?P<green>\d+),(?P<blue>\d+)::'
)
# "label:#RRGGBB" lines.
_LABEL_HEX_PATTERN = re.compile(r'(?P<label_name>[\w\s\-]+):(?P<color>#([\d\w]{6}))')

try:
    model_path = hf_hub_download(repo_id="SalmanAboAraj/dental", filename="open_source_augmentation_1.keras")
    model = load_model(model_path)
except Exception as e:
    # Keep the app up so predict() can report the failure instead of crashing at import.
    print(f"Error loading model: {e}")
    model = None


def get_sparse_predictions(predictions):
    """Collapse per-class scores on the last axis into a per-pixel class index."""
    return np.argmax(predictions, axis=-1)


def get_annotated_images(sparse_predictions, labels_colors, labels_numbers):
    """Paint each pixel with the RGB colour of its predicted class.

    Args:
        sparse_predictions: Integer array of class indices.
        labels_colors: Mapping of label name to (r, g, b).
        labels_numbers: Mapping of label name to class index.

    Returns:
        uint8 array of shape ``sparse_predictions.shape + (3,)``. Pixels whose
        class has no colour stay black.
    """
    annotated_images = np.zeros((*sparse_predictions.shape, 3), dtype=np.uint8)
    for label, number in labels_numbers.items():
        if label in labels_colors:
            color = np.array(labels_colors[label], dtype=np.uint8)
            annotated_images[sparse_predictions == number] = color
    return annotated_images


def _parse_label_line(line):
    """Return ``(label_name, (r, g, b))`` for a label-map line, or None if it doesn't match."""
    match_rgb = _LABEL_RGB_PATTERN.match(line)
    if match_rgb:
        rgb = (int(match_rgb.group('red')), int(match_rgb.group('green')), int(match_rgb.group('blue')))
        return match_rgb.group('label_name').lower(), rgb
    match_hex = _LABEL_HEX_PATTERN.match(line)
    if match_hex:
        return match_hex.group('label_name').lower(), ImageColor.getcolor(match_hex.group('color'), "RGB")
    return None


def load_labels_colors(labels_path):
    """Read a label-map file into an ordered ``{label: (r, g, b)}`` dict.

    ``'background'`` is always the first entry, with colour (0, 0, 0). Labels
    then follow in file order; that order defines the class indices
    (see ``get_labels_numbers``). Blank and unrecognised lines are skipped.

    Returns:
        The mapping, or None if the file is missing or cannot be parsed. The
        error is printed.
    """
    labels_colors = {'background': (0, 0, 0)}
    try:
        with open(labels_path, 'r', encoding='utf-8') as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line:
                    continue
                parsed = _parse_label_line(line)
                if parsed is None:
                    continue
                label_name, rgb = parsed
                labels_colors[label_name] = rgb
    except FileNotFoundError:
        print("Error: labelmap.txt not found.")
        return None
    except Exception as e:
        print(f"Error reading labelmap.txt: {e}")
        return None
    return labels_colors


def get_labels_numbers(labels_colors):
    """Map each label to its class index, using the insertion order of ``labels_colors``."""
    # NOTE(review): this order must match the model's output channels — confirm.
    return {label: i for i, label in enumerate(labels_colors.keys())}


def _decode_base64_image(image_base64):
    """Decode a base64 string into an RGB uint8 numpy array.

    An optional ``data:<mime>;base64,`` prefix is stripped first; otherwise it
    would be decoded as image bytes.
    """
    payload = image_base64.strip()
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    image_data = base64.b64decode(payload)
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    return np.array(image)


def _blend_mask(image_rgb, mask_rgb, alpha=MASK_ALPHA):
    """Alpha-blend ``mask_rgb`` over ``image_rgb`` (same shape, uint8) at the given opacity."""
    return cv2.addWeighted(image_rgb, 1.0 - alpha, mask_rgb, alpha, 0.0)


def _encode_png_base64(image_rgb):
    """Encode an RGB uint8 image as PNG and return it as a base64 string."""
    ok, buffer = cv2.imencode(".png", cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    if not ok:
        raise ValueError("PNG encoding failed")
    return base64.b64encode(buffer).decode("utf-8")


def predict(image_base64):
    """Segment a base64-encoded image and return the overlaid result.

    Args:
        image_base64: Base64-encoded image bytes, optionally as a data URL.

    Returns:
        A base64-encoded PNG at the input's original resolution, with the
        segmentation colours blended over the image. On failure, returns a
        string starting with ``"Error"`` (errors are returned, never raised).
    """
    if model is None:
        return "Error: Model failed to load. Please check the model file."

    labels_colors = load_labels_colors(LABELMAP_PATH)
    if labels_colors is None:
        return "Error: labelmap.txt not found or could not be read."
    labels_numbers = get_labels_numbers(labels_colors)

    try:
        image_np = _decode_base64_image(image_base64)
        orig_height, orig_width = image_np.shape[:2]
    except Exception as e:
        return f"Error processing image: {e}"

    try:
        image_resized = cv2.resize(image_np, MODEL_INPUT_SIZE)
        batch = np.expand_dims(image_resized, axis=0)
        mask_pred = model.predict(batch)[0]

        sparse_predictions = get_sparse_predictions(mask_pred)
        annotated_predictions = get_annotated_images(sparse_predictions, labels_colors, labels_numbers)
        # Nearest-neighbour upscaling keeps class colours crisp (no blended edge colours).
        mask_resized = cv2.resize(annotated_predictions, (orig_width, orig_height), interpolation=cv2.INTER_NEAREST)

        # Blend directly in numpy/OpenCV at the original resolution. The previous
        # pyplot-based overlay reused the global figure: images piled up across
        # requests, and the output had the canvas size rather than the image size.
        blended_image = _blend_mask(image_np, mask_resized)
        return _encode_png_base64(blended_image)
    except Exception as e:
        return f"Error during prediction: {e}"


iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Dental Image Segmentation",
    description="Upload a base64 image and get the original and segmented image in base64 format.",
)

iface.launch()