# Hello / app.py
# SalmanAboAraj's picture
# Update app.py
# c78deec verified
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
# Fetch the weights file "unet_model_256.h5" from the "SalmanAboAraj/Tooth1"
# Hub repository, then deserialize it into the Keras model used for inference.
# Both steps run once, at import time.
model_path = hf_hub_download(
    repo_id="SalmanAboAraj/Tooth1",
    filename="unet_model_256.h5",
)
model = load_model(model_path)
# Class index -> color triple painted into the segmentation mask.
# The position of each entry in the sequence below is its class index.
COLOR_MAP = dict(
    enumerate(
        (
            (0, 0, 0),  # 0: Background
            (215, 179, 255),  # 1: Bone
            (246, 51, 81),  # 2: Cavity
            (58, 132, 255),  # 3: Crown
            (134, 202, 218),  # 4: Dental Implant
            (221, 195, 130),  # 5: Dental Implant Crown
            (255, 255, 127),  # 6: Dentin
            (255, 255, 255),  # 7: Enamel
            (1, 13, 27),  # 8: Filling Metal
            (0, 133, 255),  # 9: Filling Non-Metal
            (24, 250, 143),  # 10: Periapical Radiolucence
            (255, 105, 248),  # 11: Pulp
            (17, 253, 231),  # 12: Root Canal
            (255, 146, 119),  # 13: Sinus
            (131, 224, 112),  # 14: Missing
        )
    )
)
def resize_and_predict(image_base64):
    """Segment a base64-encoded image and return the colored mask as base64.

    The image is decoded, normalized to 3-channel RGB, resized to 512x512,
    scaled down to the 256x256 array passed to ``model``, and the per-pixel
    argmax of the prediction is painted using ``COLOR_MAP``. The mask is then
    scaled back up to 512x512 and encoded as PNG.

    Args:
        image_base64: Base64 text of an encoded image file. A leading
            ``data:<mime>;base64,`` data-URL header is accepted and dropped.

    Returns:
        Base64 text (``str``) of a 512x512 PNG containing the colored mask.

    Raises:
        binascii.Error: If the payload is not decodable base64.
        OSError: If the decoded bytes are not an image PIL can open.
    """
    # Drop an optional data-URL header. base64.b64decode would otherwise
    # silently decode its alphanumeric characters as payload and corrupt
    # the image bytes. Plain base64 never contains ':' so this is a no-op
    # for inputs that worked before.
    image_base64 = image_base64.strip()
    if image_base64.startswith("data:"):
        image_base64 = image_base64.split(",", 1)[-1]

    # Decode the base64 image
    image_data = base64.b64decode(image_base64)

    # Force 3-channel RGB: grayscale / palette images would otherwise produce
    # a 2-D array (breaking the shape handling below) and RGBA images would
    # feed a 4th channel to the model.
    # NOTE(review): assumes the model takes 3-channel input, as it did for the
    # RGB images the previous code handled - confirm against the model.
    image = Image.open(BytesIO(image_data)).convert("RGB")

    # Resize the image to 512x512; this is the size of the returned mask
    resized_image = image.resize((512, 512))

    # Convert the image to numpy array for prediction
    image_np = np.array(resized_image)
    original_height, original_width = image_np.shape[:2]

    # Resize to the model input size, scale to [0, 1], add the batch axis
    image_resized = cv2.resize(image_np, (256, 256))
    image_resized = image_resized / 255.0
    image_resized = np.expand_dims(image_resized, axis=0)

    # Predict the mask (first and only item of the batch)
    mask_pred = model.predict(image_resized)[0]

    # Most likely class per pixel, then paint each class with its color.
    # Pixels whose class is missing from COLOR_MAP stay black.
    mask_class = np.argmax(mask_pred, axis=-1)
    mask_colored = np.zeros((*mask_class.shape, 3), dtype=np.uint8)
    for class_idx, color in COLOR_MAP.items():
        mask_colored[mask_class == class_idx] = color

    # Scale the mask back up; nearest-neighbour keeps colors exact (no blends)
    mask_final = cv2.resize(
        mask_colored,
        (original_width, original_height),
        interpolation=cv2.INTER_NEAREST,
    )

    # Convert the mask back to base64
    buffered = BytesIO()
    mask_final_image = Image.fromarray(mask_final)
    mask_final_image.save(buffered, format="PNG")
    mask_final_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return mask_final_base64
# Wire the predictor into a Gradio UI. Both ends are plain text boxes because
# the function takes and returns base64 strings rather than image objects.
iface = gr.Interface(
    title="Image Resizer and Predictor",
    description="Upload a base64 image, get a resized 512x512 base64 image with predictions.",
    fn=resize_and_predict,
    inputs="text",  # base64 image in
    outputs="text",  # base64 mask out
)

# Start serving the interface
iface.launch()