# SafetyWatch / app.py
# Author: DarkMo0o
# Last commit: "Update app.py" (ce7fac7, verified)
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input
import base64
import io
from PIL import Image
# ── Load model ───────────────────────────────────────────────────────────────
# The Keras model is loaded once at import time; predict_image and
# predict_base64 both reuse this module-level `model` for every request.
MODEL_PATH = "ViolenceDetector_EfficientNetB0.keras"
print("Loading model...")
model = tf.keras.models.load_model(MODEL_PATH)
print("✅ Model loaded!")
# Target size handed to PIL's Image.resize, which takes (width, height).
IMG_SIZE = (224, 224)
# ── Helper ───────────────────────────────────────────────────────────────────
def preprocess_frame(img_pil):
    """Turn one PIL image into a model-ready batch containing that image.

    The image is forced to RGB, resized to ``IMG_SIZE`` and cast to
    float32. A leading batch axis is added, and the result is passed through
    EfficientNet's ``preprocess_input``.

    Args:
        img_pil: A ``PIL.Image.Image`` in any mode.

    Returns:
        A float32 array of shape ``(1, IMG_SIZE[1], IMG_SIZE[0], 3)``, as
        returned by ``preprocess_input``.
    """
    resized = img_pil.convert("RGB").resize(IMG_SIZE)
    # NumPy lays PIL pixels out as (height, width, channels).
    pixels = np.array(resized).astype(np.float32)
    batch = pixels[np.newaxis, ...]
    return preprocess_input(batch)
# ── Predict for Gradio UI ─────────────────────────────────────────────────────
def predict_image(img_pil):
    """Classify one uploaded frame for the ``gr.Label`` widget in the Test UI.

    Args:
        img_pil: The image from ``gr.Image(type="pil")``, or ``None`` when
            Detect is clicked with nothing uploaded.

    Returns:
        A single-entry ``{label: confidence}`` dict.
        - ``label`` is ``"🔴 Violence"`` when the model score is >= 0.5,
          otherwise ``"🟢 Safe"``.
        - ``confidence`` is the probability of the chosen label, rounded to
          4 decimals.
        - ``{"Error": 1.0}`` is returned when no image was supplied.
    """
    if img_pil is None:
        return {"Error": 1.0}
    inp = preprocess_frame(img_pil)
    # The first output of the only batch item is treated as P(violence).
    prob = float(model.predict(inp, verbose=0)[0][0])
    is_violent = prob >= 0.5
    label = "🔴 Violence" if is_violent else "🟢 Safe"
    confidence = prob if is_violent else 1.0 - prob
    return {label: round(confidence, 4)}
# ── API via Gradio's built-in endpoint ───────────────────────────────────────
def predict_base64(frame_b64: str):
    """Classify a base64-encoded image and return a JSON-serialisable dict.

    Args:
        frame_b64: The encoded image as a base64 string. A data URL such as
            ``"data:image/jpeg;base64,<payload>"`` (as produced by a browser
            canvas) is also accepted: everything up to the first comma is
            dropped. Surrounding whitespace is ignored.

    Returns:
        On success::

            {"violence": bool, "probability": float,
             "label": "Violence" | "NonViolence", "confidence": float}

        ``probability`` is the model's raw score rounded to 4 decimals, and
        ``confidence`` is the probability of the returned label.

        On failure (empty input, bad base64, undecodable image, model error)
        the result is ``{"error": "<message>"}`` instead of an exception.
    """
    try:
        payload = frame_b64
        if isinstance(payload, str):
            payload = payload.strip()
            if payload.startswith("data:"):
                # Drop the "data:<mime>;base64," header. ',' and ':' are not
                # in the base64 alphabet, so a raw payload never matches.
                payload = payload.partition(",")[2]
        if not payload:
            return {"error": "empty base64 payload"}
        img_bytes = base64.b64decode(payload)
        with Image.open(io.BytesIO(img_bytes)) as img:
            # convert() loads the pixel data, so corrupt images fail here.
            img_pil = img.convert("RGB")
        inp = preprocess_frame(img_pil)
        prob = float(model.predict(inp, verbose=0)[0][0])
    except Exception as e:  # API boundary: report any failure as JSON.
        return {"error": str(e)}
    is_violent = prob >= 0.5
    return {
        "violence": is_violent,
        "probability": round(prob, 4),
        "label": "Violence" if is_violent else "NonViolence",
        "confidence": round(prob if is_violent else 1.0 - prob, 4),
    }
# ── Gradio UI with API endpoints ──────────────────────────────────────────────
# One Blocks app with two tabs: an interactive image tester and a textbox
# front-end for the base64 endpoint. The explicit api_name values pin the
# endpoint names clients call ("/predict_image", "/predict_base64").
with gr.Blocks(title="Violence Detection API") as demo:
    gr.Markdown("# 🔍 Violence Detection API")
    gr.Markdown("Upload a video frame to detect violence. Also available as REST API.")

    with gr.Tab("🖼️ Test UI"):
        with gr.Row():
            img_input = gr.Image(type="pil", label="Upload Frame")
            result_out = gr.Label(label="Result")
        btn = gr.Button("Detect", variant="primary")
        btn.click(
            fn=predict_image,
            inputs=img_input,
            outputs=result_out,
            api_name="predict_image",
        )

    with gr.Tab("🔌 API - Base64"):
        gr.Markdown("### POST to `/predict_base64`")
        # Gradio endpoints take their inputs as a positional "data" list.
        gr.Markdown(
            'Send: `{"data": ["<base64_image>"]}` → '
            "Receive: `{violence, probability, label, confidence}` "
            "(or `{error}` on failure)"
        )
        b64_input = gr.Textbox(label="Base64 Image", placeholder="Paste base64 string here...")
        b64_result = gr.JSON(label="Result")
        b64_btn = gr.Button("Test API")
        b64_btn.click(
            fn=predict_base64,
            inputs=b64_input,
            outputs=b64_result,
            api_name="predict_base64",
        )

demo.launch()