Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import qrcode | |
| import numpy as np | |
| import gradio as gr | |
| import mediapipe as mp | |
| import tensorflow as tf | |
| from PIL import Image, ImageOps | |
# --- 1. Load the Pre-trained TFLite Model ---
# 'model.tflite' is expected to be uploaded alongside this script; it is
# opened by a path relative to the current working directory.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Tensor metadata is fetched once here and reused by predict_rps on every call.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("Interpreter is ready for edge predictions!")

# --- 2. Setup MediaPipe for Background Removal ---
# A single module-level segmenter instance is shared by all calls to
# remove_background.
mp_selfie_segmentation = mp.solutions.selfie_segmentation
segmenter = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
def remove_background(image_pil):
    """Return a grayscale copy of ``image_pil`` with its background whitened.

    The image is converted to RGB and passed through the module-level
    MediaPipe selfie segmenter. Pixels whose mask value is above 0.5 keep
    their original colour; every other pixel becomes white. The composite is
    converted to single-channel ("L") mode before being returned.
    """
    rgb = np.array(image_pil.convert("RGB"), copy=True)

    mask = segmenter.process(rgb).segmentation_mask

    # A trailing axis lets the per-pixel mask broadcast over the colour
    # channels (assumes the mask is height x width -- TODO confirm).
    keep = mask[..., np.newaxis] > 0.5
    white = np.full(rgb.shape, 255, dtype=np.uint8)

    composite = np.where(keep, rgb, white)
    return Image.fromarray(composite).convert("L")
# --- 3. Prediction Logic ---
def predict_rps(image):
    """Classify a hand photo as rock, paper or scissors.

    Returns a ``(preview, scores)`` pair: ``preview`` is the 300x300 image
    that was fed to the model, and ``scores`` maps each class label to the
    model's output value for that class. When ``image`` is ``None`` the
    result is ``(None, None)``.
    """
    if image is None:
        return None, None

    # STEP 1 + 2: whiten the background, then centre-crop/resize to 300x300.
    foreground = remove_background(image)
    processed_image = ImageOps.fit(foreground, (300, 300), Image.Resampling.LANCZOS)

    # STEP 3: scale pixels to [0, 1], then add a trailing channel axis and a
    # leading batch axis.
    pixels = np.array(processed_image, dtype=np.float32)
    batch = (pixels[..., np.newaxis] / 255.0)[np.newaxis, ...]

    # STEP 4: run the TFLite interpreter on the prepared batch.
    interpreter.set_tensor(input_details[0]['index'], batch)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_details[0]['index'])[0]

    # STEP 5: pair each label with the score at the same position.
    class_names = ['Rock ✊', 'Paper ✋', 'Scissors ✌️']
    prediction_dict = {
        label: float(scores[position])
        for position, label in enumerate(class_names)
    }
    return processed_image, prediction_dict
# --- 4. QR Code Helper ---
def generate_qr_code(url):
    """Render ``url`` as a black-on-white QR code image.

    Returns ``None`` when ``url`` is falsy (empty string or ``None``);
    otherwise returns the image object produced by the ``qrcode`` library.
    """
    if url:
        code = qrcode.QRCode(version=1, box_size=10, border=4)
        code.add_data(url)
        # fit=True is passed so the code may grow past the requested version.
        code.make(fit=True)
        rendered = code.make_image(fill_color="black", back_color="white")
        return rendered.get_image()
    return None
# --- 5. Gradio UI ---
with gr.Blocks(theme="default", title="Edge AI: ✊Rock - ✋Paper - ✌️Scissors") as demo:
    gr.Markdown("# Edge AI: ✊Rock - ✋Paper - ✌️Scissors")
    gr.Markdown("This model is running on TFLite! Take a photo of your hand.")

    with gr.Row():
        # Left column: image capture/upload and the trigger button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Snap a Photo", sources=["webcam", "upload"])
            predict_btn = gr.Button("Predict", variant="primary")

        # Right column: the preprocessed image and the class scores.
        with gr.Column():
            output_img = gr.Image(label="What the AI Sees")
            output_label = gr.Label(num_top_classes=3, label="Prediction")

    predict_btn.click(fn=predict_rps, inputs=input_img, outputs=[output_img, output_label])

    # QR Code section
    gr.Markdown("---")
    with gr.Row():
        qr_display = gr.Image(
            label="Share this app",
            show_download_button=False,
            height=300,
            width=300,
        )
        # Hidden textbox that carries the page URL from the browser to Python.
        url_state = gr.Textbox(visible=False)

    # On page load the JS snippet supplies the browser's current URL, which
    # generate_qr_code turns into the QR image. Inside the Hugging Face Spaces
    # iframe the reported URL may differ from the public one; kept simple.
    demo.load(
        fn=generate_qr_code,
        inputs=[url_state],
        outputs=[qr_display],
        js="() => window.location.href",
    )

if __name__ == "__main__":
    demo.launch()