# ToothXRay / app.py
# Author: SalmanAboAraj (Hugging Face Space)
# Last change: "Update app.py", commit ae56d88 (verified)
import numpy as np
import cv2
import base64
import re
import io
from PIL import Image, ImageColor
import gradio as gr
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
# Fetch the trained Keras segmentation model from the Hub at startup.
# On any failure we keep running with model=None; predict() then reports
# the problem to the caller instead of crashing the whole Space.
try:
    _model_path = hf_hub_download(
        repo_id="SalmanAboAraj/dental",
        filename="open_source_augmentation_1.keras",
    )
    model = load_model(_model_path)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
def get_sparse_predictions(predictions):
    """Collapse per-class score maps to an integer label map.

    Picks, for every pixel, the class with the highest score along the
    last axis.
    """
    return predictions.argmax(axis=-1)
def get_annotated_images(sparse_predictions, labels_colors, labels_numbers):
    """Paint an RGB image from an integer label map.

    Parameters:
        sparse_predictions: integer array of class indices (any shape).
        labels_colors: mapping of label name -> (R, G, B) tuple.
        labels_numbers: mapping of label name -> class index.

    Returns:
        uint8 RGB array of shape (*sparse_predictions.shape, 3); pixels whose
        label has no color entry remain black.
    """
    canvas = np.zeros(sparse_predictions.shape + (3,), dtype=np.uint8)
    for name, index in labels_numbers.items():
        rgb = labels_colors.get(name)
        if rgb is not None:
            canvas[sparse_predictions == index] = np.asarray(rgb, dtype=np.uint8)
    return canvas
def load_labels_colors(labels_path):
    """Parse a labelmap file into ``{label_name: (R, G, B)}``.

    Two line formats are accepted:
        name:R,G,B::      (decimal channel values)
        name:#RRGGBB      (hex color)
    Label names are lower-cased; a ``'background' -> (0, 0, 0)`` entry is
    always present. Lines matching neither format are silently skipped.

    Returns:
        The label-to-color dict, or None if the file is missing or unreadable.
    """
    labels_colors = {'background': (0, 0, 0)}
    rgb_pattern = re.compile(
        r'(?P<label_name>[\w\s\-]+):(?P<red>\d+),(?P<green>\d+),(?P<blue>\d+)::')
    # Restrict hex digits to [0-9a-fA-F]; the previous class [\d\w] also
    # matched letters g-z and '_', which ImageColor.getcolor rejects at
    # runtime, turning one malformed line into a failure for the whole file.
    hex_pattern = re.compile(r'(?P<label_name>[\w\s\-]+):(?P<color>#[0-9a-fA-F]{6})')
    try:
        with open(labels_path, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                rgb_match = rgb_pattern.match(line)
                hex_match = hex_pattern.match(line)
                if rgb_match:
                    label_name = rgb_match.group('label_name').lower()
                    red = int(rgb_match.group('red'))
                    green = int(rgb_match.group('green'))
                    blue = int(rgb_match.group('blue'))
                elif hex_match:
                    label_name = hex_match.group('label_name').lower()
                    red, green, blue = ImageColor.getcolor(hex_match.group('color'), "RGB")
                else:
                    continue
                labels_colors[label_name] = (red, green, blue)
    except FileNotFoundError:
        # Report the actual path (the old message hard-coded 'labelmap.txt').
        print(f"Error: {labels_path} not found.")
        return None
    except Exception as e:
        print(f"Error reading {labels_path}: {e}")
        return None
    return labels_colors
def get_labels_numbers(labels_colors):
    """Assign each label a class index based on its insertion order.

    Index 0 belongs to whichever label was inserted first (normally
    'background', which load_labels_colors seeds into the dict).
    """
    return dict(zip(labels_colors, range(len(labels_colors))))
def predict(image_base64):
    """Segment a dental X-ray supplied as a base64-encoded image.

    Parameters:
        image_base64: base64-encoded bytes of an image file (any format
            PIL can decode).

    Returns:
        On success, a base64-encoded PNG of the original image with the
        predicted segmentation mask alpha-blended on top at the original
        resolution; on failure, a human-readable error string.
    """
    if model is None:
        return "Error: Model failed to load. Please check the model file."
    labels_colors = load_labels_colors('labelmap.txt')
    if labels_colors is None:
        return "Error: labelmap.txt not found or could not be read."
    labels_numbers = get_labels_numbers(labels_colors)
    try:
        image_data = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image_np = np.array(image)
        orig_height, orig_width = image_np.shape[:2]
    except Exception as e:
        return f"Error processing image: {e}"
    try:
        # Model expects a 256x256 batch. NOTE(review): no normalization is
        # applied here — confirm the model was trained on raw 0-255 inputs.
        image_resized = cv2.resize(image_np, (256, 256))
        batch = np.expand_dims(image_resized, axis=0)
        mask_pred = model.predict(batch)[0]
        sparse_predictions = get_sparse_predictions(mask_pred)
        annotated = get_annotated_images(sparse_predictions, labels_colors, labels_numbers)
        # Nearest-neighbour keeps label colors exact when scaling back up.
        mask_resized = cv2.resize(
            annotated, (orig_width, orig_height), interpolation=cv2.INTER_NEAREST)
        # Alpha-blend the mask onto the original only where a label was
        # predicted (non-black pixels), preserving the original resolution.
        # This replaces the previous matplotlib-canvas round trip, which
        # leaked a figure per request, rendered the result at figure
        # resolution with axis margins, and darkened background pixels.
        alpha = 0.5
        labelled = mask_resized.any(axis=-1)
        blended = image_np.copy()
        blended[labelled] = (
            (1 - alpha) * image_np[labelled] + alpha * mask_resized[labelled]
        ).astype(np.uint8)
        ok, buffer = cv2.imencode(".png", cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))
        if not ok:
            return "Error during prediction: PNG encoding failed."
        return base64.b64encode(buffer).decode("utf-8")
    except Exception as e:
        return f"Error during prediction: {e}"
# Gradio wiring: plain-text input and output because images travel to and
# from the client as base64-encoded strings rather than image widgets.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Dental Image Segmentation",
    description="Upload a base64 image and get the original and segmented image in base64 format.",
)
iface.launch()