# teeth / app.py — dounan1, first commit (cc66a3d)
import gradio as gr
import cv2
import numpy as np
import requests
import os
import tempfile
# --- Configuration for Model Files ---
# IMPORTANT: These files must be present in your Hugging Face Space repository.
# 1. Haar Cascade for face detection (bundled with the opencv-python wheel,
#    located via cv2.data.haarcascades — no extra download needed).
FACE_CASCADE_PATH: str = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
# 2. LBF Facial Landmark Model (lbfmodel.yaml), resolved relative to the
#    working directory (the Space's root on Hugging Face).
# You need to download 'lbfmodel.yaml' and add it to your Hugging Face Space.
# A common source is: https://github.com/kurnianggoro/GSOC2017/blob/master/data/lbfmodel.yaml
LBF_MODEL_PATH: str = "lbfmodel.yaml"
# --- Core Image Processing and AI Editing Function ---
def _build_mouth_mask(img_bgr):
    """Detect faces and facial landmarks, then build a binary mouth mask.

    Args:
        img_bgr (numpy.ndarray): Input image in BGR (OpenCV) order.

    Returns:
        tuple: (mask, error_message) — ``mask`` is a uint8 binary image
        (255 inside dilated mouth regions, 0 elsewhere) and ``error_message``
        is None on success; on a recoverable failure ``mask`` is None and
        ``error_message`` explains why.

    Raises:
        gr.Error: If the cascade or LBF model files cannot be loaded
        (configuration problems, not per-image failures).
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    face_cascade = cv2.CascadeClassifier(FACE_CASCADE_PATH)
    if face_cascade.empty():
        # This error is unlikely if OpenCV is correctly installed.
        raise gr.Error("Fatal Error: Could not load face cascade model. Check OpenCV installation.")
    if not os.path.exists(LBF_MODEL_PATH):
        raise gr.Error(
            f"Error: LBF Facial Landmark Model ('{LBF_MODEL_PATH}') not found. "
            "Please download it and add it to your Hugging Face Space's root directory."
        )
    facemark = cv2.face.createFacemarkLBF()
    try:
        facemark.loadModel(LBF_MODEL_PATH)
    except cv2.error as e:
        raise gr.Error(f"Error loading LBF model from '{LBF_MODEL_PATH}': {e}. Ensure the file is valid.")

    # minSize keeps tiny false-positive detections out of the landmark fit.
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(50, 50))
    if len(faces) == 0:
        return None, "No faces detected in the image. Try a clearer image or different pose."

    ok, landmarks_all_faces = facemark.fit(img_bgr, faces)
    if not ok:
        return None, "Facial landmark detection failed. Could not identify facial features."

    mouth_mask = np.zeros_like(gray)  # empty (black) mask, same size as the image
    num_mouths_masked = 0
    for landmarks_one_face in landmarks_all_faces:
        shape = landmarks_one_face[0].astype(np.int32)
        # Indices 48-59 are the outer-mouth points in the 68-point model.
        mouth_points = shape[48:60]
        if len(mouth_points) > 2:  # need at least 3 points for a convex hull
            mouth_hull = cv2.convexHull(mouth_points)
            cv2.drawContours(mouth_mask, [mouth_hull], -1, (255), -1)
            num_mouths_masked += 1

    if num_mouths_masked == 0:
        return None, "Could not identify or create a mask for any mouth regions."

    # Dilate slightly so the inpainting mask fully covers the lip edges.
    kernel = np.ones((5, 5), np.uint8)
    mouth_mask = cv2.dilate(mouth_mask, kernel, iterations=1)
    _, mouth_mask_binary = cv2.threshold(mouth_mask, 127, 255, cv2.THRESH_BINARY)
    return mouth_mask_binary, None


def _call_stability_inpaint(img_bgr, mask_binary, stability_api_key, prompt_text):
    """Send image + mask to the Stability AI inpaint endpoint.

    Encodes both arrays to PNG in memory (no temp files, no leaked file
    handles) and uploads them as multipart form data.

    Args:
        img_bgr (numpy.ndarray): Source image in BGR order.
        mask_binary (numpy.ndarray): Binary mask (255 = region to repaint).
        stability_api_key (str): Bearer token for the Stability AI API.
        prompt_text (str): Inpainting prompt.

    Returns:
        tuple: (edited RGB image or None, status message).
    """
    ok_img, img_png = cv2.imencode(".png", img_bgr)
    ok_mask, mask_png = cv2.imencode(".png", mask_binary)
    if not (ok_img and ok_mask):
        return None, "⚠️ Failed to encode image data for the API request."

    try:
        response = requests.post(
            "https://api.stability.ai/v2beta/stable-image/edit/inpaint",
            headers={
                "authorization": f"Bearer {stability_api_key}",
                "accept": "image/*",  # expecting an image in response
            },
            # Bytes tuples instead of open file handles: nothing to close,
            # nothing left on disk.
            files={
                "image": ("image.png", img_png.tobytes(), "image/png"),
                "mask": ("mask.png", mask_png.tobytes(), "image/png"),
            },
            data={
                "prompt": prompt_text,
                "output_format": "png",
                # Other parameters like 'seed' or 'strength' can be added here.
            },
            timeout=90,  # generous timeout for potentially slow API responses
        )
    except requests.exceptions.Timeout:
        return None, "⚠️ Stability API request timed out. The server might be busy or the image too large. Please try again later."
    except requests.exceptions.RequestException as e:
        return None, f"⚠️ Request to Stability API failed: {e}"

    if response.status_code != 200:
        # Try to parse a JSON error payload, otherwise fall back to raw text.
        try:
            error_detail = response.json()
            # Stability AI errors are often a list under 'errors'; guard
            # against an empty list to avoid an IndexError.
            error_messages = error_detail.get('errors') or [response.text]
            return None, f"⚠️ Stability API Error ({response.status_code}): {error_messages[0]}"
        except requests.exceptions.JSONDecodeError:
            return None, f"⚠️ Stability API Error ({response.status_code}): {response.text}"

    nparr = np.frombuffer(response.content, np.uint8)
    edited_image_cv_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)  # decoded as BGR
    if edited_image_cv_bgr is None:
        return None, "⚠️ Stability API returned an invalid image format."
    # Convert BGR back to RGB for Gradio display.
    return cv2.cvtColor(edited_image_cv_bgr, cv2.COLOR_BGR2RGB), "✅ Image edited successfully!"


def edit_image_with_ai(input_image_np_rgb, stability_api_key, prompt_text="perfect white teeth"):
    """Detect mouth regions in an image and inpaint them via Stability AI.

    Args:
        input_image_np_rgb (numpy.ndarray): The input image in RGB format.
        stability_api_key (str): The Stability AI API key.
        prompt_text (str): The prompt for the AI image editing.

    Returns:
        tuple: (original_image, edited_image, status_message)
            original_image is the input image.
            edited_image is the AI-edited image (or None on failure).
            status_message provides feedback on the process.
    """
    # Input validation: each guard returns a user-facing status message.
    if input_image_np_rgb is None:
        return None, None, "Please upload an image or use the webcam."
    if not stability_api_key:
        return input_image_np_rgb, None, "⚠️ Please enter your Stability AI API Key."
    if not prompt_text:
        return input_image_np_rgb, None, "⚠️ Please enter a prompt for the edit."

    try:
        # Convert RGB (from Gradio) to BGR (for OpenCV).
        img_bgr = cv2.cvtColor(input_image_np_rgb, cv2.COLOR_RGB2BGR)

        mouth_mask_binary, mask_error = _build_mouth_mask(img_bgr)
        if mouth_mask_binary is None:
            return input_image_np_rgb, None, mask_error

        edited_image_np_rgb, status_message = _call_stability_inpaint(
            img_bgr, mouth_mask_binary, stability_api_key, prompt_text
        )
        return input_image_np_rgb, edited_image_np_rgb, status_message
    except gr.Error as e:  # intentionally-raised configuration errors
        return input_image_np_rgb, None, str(e)
    except Exception as e:
        # Catch any other unexpected errors during CV processing.
        return input_image_np_rgb, None, f"⚠️ An unexpected error occurred in processing: {e}"
# --- Gradio Interface Definition ---
# NOTE: Gradio renders components in declaration order, so the statement
# order inside this Blocks context defines the page layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
    # Page title and short description.
    gr.Markdown(
        "<h1 style='text-align: center;'>👄 Your Best Smile With Keen Dental 😄 </h1>"
        "<p style='text-align: center;'>Upload an image or use your webcam. The app detects faces, "
        " and uses AI to change your teeth based on your prompt.</p>"
    )
    # Step-by-step usage instructions.
    gr.Markdown(
        "**Instructions:**\n"
        "1. Enter your Stability AI API key (e.g., `sk-xxxxxxxx`). Get one from [Stability AI Developer Platform](https://platform.stability.ai/).\n"
        "2. Upload an image or use your webcam.\n"
        "3. Click 'Smile!'\n\n"
    )
    with gr.Row():
        # type="password" masks the key in the UI; it is never stored server-side.
        api_key_input = gr.Textbox(
            label="🔑 Stability AI API Key",
            placeholder="Enter your sk-xxxxxxxx API key here",
            type="password",
            lines=1,
            elem_id="api_key_input"
        )
    with gr.Row():
        # type="numpy" delivers the image as an RGB ndarray to the handler.
        input_image = gr.Image(label="🖼️ Input Image", type="numpy", sources=["upload", "webcam"], height=400, interactive=True)
        output_image = gr.Image(label="✨ Edited Image", type="numpy", height=400, interactive=False)
    process_button = gr.Button("🚀 Smile!", variant="primary")
    status_output = gr.Textbox(label="ℹ️ Status", interactive=False, lines=2)
    # prompt_text is not wired to a UI component, so the handler's default
    # prompt ("perfect white teeth") is always used.
    process_button.click(
        fn=edit_image_with_ai,
        inputs=[input_image, api_key_input],
        outputs=[input_image, output_image, status_output],  # Display original again for reference
        api_name="edit_mouth"  # For API access if needed
    )
# To run this locally for testing (optional):
if __name__ == "__main__":
    # debug=True surfaces tracebacks in the console while developing locally.
    demo.launch(debug=True)
# For Hugging Face Spaces, save this as 'app.py'.
# Gradio will automatically call demo.launch().