Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.applications.efficientnet import preprocess_input | |
| import base64 | |
| import io | |
| from PIL import Image | |
| # ββ Load model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("Loading model...") | |
| model = tf.keras.models.load_model("ViolenceDetector_EfficientNetB0.keras") | |
| print("β Model loaded!") | |
| IMG_SIZE = (224, 224) | |
| # ββ Helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def preprocess_frame(img_pil): | |
| img = img_pil.convert("RGB").resize(IMG_SIZE) | |
| arr = np.expand_dims(np.array(img).astype(np.float32), axis=0) | |
| return preprocess_input(arr) | |
| # ββ Predict for Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def predict_image(img_pil): | |
| if img_pil is None: | |
| return {"Error": 1.0} | |
| inp = preprocess_frame(img_pil) | |
| prob = float(model.predict(inp, verbose=0)[0][0]) | |
| label = "π΄ Violence" if prob >= 0.5 else "π’ Safe" | |
| confidence = prob if prob >= 0.5 else 1 - prob | |
| return {label: float(round(confidence, 4))} | |
| # ββ API via Gradio's built-in endpoint βββββββββββββββββββββββββββββββββββββββ | |
| def predict_base64(frame_b64: str): | |
| """Accepts base64 image string, returns JSON result""" | |
| try: | |
| img_bytes = base64.b64decode(frame_b64) | |
| img_pil = Image.open(io.BytesIO(img_bytes)).convert("RGB") | |
| inp = preprocess_frame(img_pil) | |
| prob = float(model.predict(inp, verbose=0)[0][0]) | |
| return { | |
| "violence": prob >= 0.5, | |
| "probability": round(prob, 4), | |
| "label": "Violence" if prob >= 0.5 else "NonViolence", | |
| "confidence": round(prob if prob >= 0.5 else 1 - prob, 4), | |
| } | |
| except Exception as e: | |
| return {"error": str(e)} | |
| # ββ Gradio UI with API endpoints ββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks(title="Violence Detection API") as demo: | |
| gr.Markdown("# π Violence Detection API") | |
| gr.Markdown("Upload a video frame to detect violence. Also available as REST API.") | |
| with gr.Tab("πΌοΈ Test UI"): | |
| with gr.Row(): | |
| img_input = gr.Image(type="pil", label="Upload Frame") | |
| result_out = gr.Label(label="Result") | |
| btn = gr.Button("Detect", variant="primary") | |
| btn.click(fn=predict_image, inputs=img_input, outputs=result_out) | |
| with gr.Tab("π API - Base64"): | |
| gr.Markdown("### POST to `/predict_base64`") | |
| gr.Markdown("Send: `{frame: '<base64_image>'}` β Receive: `{violence, probability, label}`") | |
| b64_input = gr.Textbox(label="Base64 Image", placeholder="Paste base64 string here...") | |
| b64_result = gr.JSON(label="Result") | |
| b64_btn = gr.Button("Test API") | |
| b64_btn.click(fn=predict_base64, inputs=b64_input, outputs=b64_result) | |
| demo.launch() |