import os
import cv2
import qrcode
import numpy as np
import gradio as gr
import mediapipe as mp
import tensorflow as tf
from PIL import Image, ImageOps
# --- 1. Load the Pre-trained TFLite Model ---
# We assume 'model.tflite' is uploaded alongside this script
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("Interpreter is ready for edge predictions!")
# --- 2. Setup MediaPipe for Background Removal ---
mp_selfie_segmentation = mp.solutions.selfie_segmentation
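# model_selection=0 picks MediaPipe's general-purpose segmentation model;
# model_selection=1 would pick the landscape-optimized variant.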
segmenter = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
def remove_background(image_pil):
    """
    Uses MediaPipe to segment the person from the background and replaces the background with white.
    """
    # Convert PIL image to numpy array (RGB)
    image_np = np.array(image_pil.convert("RGB"), copy=True)
    # Get the segmentation mask
    results = segmenter.process(image_np)
    mask = results.segmentation_mask
    # Create a condition: if probability > 0.5, it's a hand/person
    condition = np.stack((mask,) * 3, axis=-1) > 0.5
    # Create a solid white background
    bg_image = np.ones(image_np.shape, dtype=np.uint8) * 255
    # Combine: keep the original pixels where the condition is true, else use the white background
    output_image = np.where(condition, image_np, bg_image)
    # Return as a grayscale PIL image (the model works on single-channel input)
    return Image.fromarray(output_image).convert("L")
# --- 3. Prediction Logic ---
def predict_rps(image):
    if image is None:
        return None, None
    # STEP 1: Remove Background
    clean_image = remove_background(image)
    # STEP 2: Smart Crop & Resize to 300x300
    processed_image = ImageOps.fit(clean_image, (300, 300), Image.Resampling.LANCZOS)
    # STEP 3: Prepare for Model
    input_data = np.array(processed_image, dtype=np.float32)
    input_data = np.expand_dims(input_data, axis=-1) / 255.0
    input_data = np.expand_dims(input_data, axis=0)
    # STEP 4: Inference using Interpreter
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # STEP 5: Format Results
    class_names = ['Rock ✊', 'Paper ✋', 'Scissors ✌️']
    prediction_dict = {class_names[i]: float(output_data[0][i]) for i in range(3)}
    return processed_image, prediction_dict
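# Quick local test of the full pipeline, independent of the Gradio UI
# ("hand_sample.jpg" is a hypothetical file name -- substitute any hand photo):
# preview, scores = predict_rps(Image.open("hand_sample.jpg"))
# print(scores)  # dict mapping each class name to the model's score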
# --- 4. QR Code Helper ---
def generate_qr_code(url):
    if not url:
        return None
    qr = qrcode.QRCode(version=1, box_size=10, border=4)
    qr.add_data(url)
    qr.make(fit=True)
    return qr.make_image(fill_color="black", back_color="white").get_image()
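# Example: generate_qr_code("https://huggingface.co/spaces/<user>/<space>") returns
# a PIL image of the QR code; demo.load below feeds in the page's own URL via JS.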
# --- 5. Gradio UI ---
with gr.Blocks(theme="default", title="Edge AI: ✊Rock - ✋Paper - ✌️Scissors") as demo:
    gr.Markdown("# Edge AI: ✊Rock - ✋Paper - ✌️Scissors")
    gr.Markdown("This model is running on TFLite! Take a photo of your hand.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="Snap a Photo",
                sources=["webcam", "upload"]
            )
            predict_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            output_img = gr.Image(label="What the AI Sees")
            output_label = gr.Label(num_top_classes=3, label="Prediction")

    predict_btn.click(
        fn=predict_rps,
        inputs=input_img,
        outputs=[output_img, output_label]
    )

    # QR Code section
    gr.Markdown("---")
    with gr.Row():
        qr_display = gr.Image(label="Share this app", show_download_button=False, height=300, width=300)
    url_state = gr.Textbox(visible=False)

    # In Hugging Face Spaces, we just want to encode the page's own URL as the QR code.
    # The JS injection might behave differently inside an iframe, but we'll keep it simple.
    demo.load(
        fn=generate_qr_code,
        inputs=[url_state],
        outputs=[qr_display],
        js="() => window.location.href"
    )
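# launch() with no arguments is all a Space needs; when running locally you could
# pass options such as share=True (a standard Gradio flag) for a temporary public link.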
if __name__ == "__main__":
    demo.launch()