# app.py — Crowd Behavior Analyzer (Gradio Space)
import os
import subprocess
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import tempfile
import base64
import numpy as np
# ------------------ CONFIG ------------------
# Number of frames the CNN-LSTM consumes per prediction.
SEQUENCE_LENGTH = 16
# Number of output behavior classes.
NUM_CLASSES = 4
# Checkpoint file expected alongside this script.
MODEL_PATH = "best_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): index order must match the label encoding used at training
# time — cannot be verified from this file; confirm against the training code.
CLASS_NAMES = ["aggressive", "idle", "panic", "normal"]
# ------------------ MODEL ------------------
class CNNLSTM(nn.Module):
    """Per-frame CNN encoder followed by an LSTM over the frame sequence.

    Each frame is encoded independently by a small two-stage conv stack,
    the per-frame features are flattened into a sequence, fed through an
    LSTM, and the final timestep's hidden state is projected to logits.
    """

    def __init__(self, num_classes):
        super(CNNLSTM, self).__init__()
        # Two conv/ReLU/pool stages: 3 -> 32 -> 64 channels, spatial size / 4.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 64 channels at 16x16 — assumes 64x64 input frames (64/4 = 16).
        self.lstm = nn.LSTM(64 * 16 * 16, 128, batch_first=True)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map (batch, time, channels, height, width) -> (batch, num_classes)."""
        batch, steps, chans, height, width = x.size()
        # Fold time into the batch so the CNN sees frames independently.
        frames = x.view(batch * steps, chans, height, width)
        feats = self.cnn(frames)
        # Restore the sequence dimension for the LSTM.
        seq = feats.view(batch, steps, -1)
        outputs, _ = self.lstm(seq)
        # Classify from the last timestep only.
        return self.fc(outputs[:, -1, :])
# ------------------ LOAD MODEL ------------------
def load_model():
    """Build a CNNLSTM, restore weights from MODEL_PATH, and return it in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file is not present.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError("Upload best_model.pth to the Space!")
    net = CNNLSTM(NUM_CLASSES).to(device)
    state = torch.load(MODEL_PATH, map_location=device)
    net.load_state_dict(state)
    net.eval()
    return net
# Load the model once at import time. Keep the app alive (with a visible
# log line) if the checkpoint is missing or fails to load; handlers check
# for `model is None`. The original bare `except:` silently swallowed
# every error, including SystemExit/KeyboardInterrupt.
try:
    model = load_model()
except Exception as exc:
    print(f"Model load failed: {exc}")
    model = None
# ------------------ FRAME EXTRACTION (FFmpeg) ------------------
def extract_frames_ffmpeg(video_path):
    """Sample SEQUENCE_LENGTH RGB frames from a video using ffmpeg.

    Extracts frames at 1 fps scaled to 320x180, then evenly subsamples
    (or cycles, when too few) the extracted frames down to exactly
    SEQUENCE_LENGTH PIL images.

    Args:
        video_path: path to a video file readable by ffmpeg.

    Returns:
        list of SEQUENCE_LENGTH PIL.Image.Image objects, or None when
        ffmpeg is unavailable or produced no frames.
    """
    tmp_dir = tempfile.mkdtemp()
    out_pattern = os.path.join(tmp_dir, "frame_%03d.jpg")
    # Fix: global options must come before -i and output options before the
    # output pattern. The original placed -hide_banner/-loglevel AFTER the
    # output, where ffmpeg treats them as trailing options and ignores them.
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel", "error",
        "-i", video_path,
        "-vf", "fps=1,scale=320:180",
        out_pattern,
    ]
    try:
        # check=False: a failed extraction falls through to the empty-dir
        # check below and is reported as None rather than an exception.
        subprocess.run(cmd, check=False)
    except FileNotFoundError:
        # ffmpeg binary is not installed on this host.
        return None
    jpgs = sorted(
        os.path.join(tmp_dir, f) for f in os.listdir(tmp_dir) if f.endswith(".jpg")
    )
    if not jpgs:
        return None
    if len(jpgs) >= SEQUENCE_LENGTH:
        # Evenly spaced sample spanning the whole clip.
        idxs = np.linspace(0, len(jpgs) - 1, SEQUENCE_LENGTH).astype(int)
        jpgs = [jpgs[i] for i in idxs]
    else:
        # Too few frames: cycle them until the sequence is full.
        jpgs = (jpgs * SEQUENCE_LENGTH)[:SEQUENCE_LENGTH]
    # .convert() forces a full decode, so the temp files are safe to discard
    # afterwards (left to the OS tempdir cleanup, as before).
    return [Image.open(f).convert("RGB") for f in jpgs]
# ------------------ PREDICTION ------------------
# ------------------ PREDICTION ------------------
# Per-frame preprocessing: resize to the 64x64 resolution CNNLSTM expects,
# then convert to a CHW float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
def do_predict(frames):
    """Run the loaded model over a list of PIL frames.

    Args:
        frames: PIL images forming one input sequence.

    Returns:
        dict mapping each class name to its softmax probability, or an
        error dict when the model failed to load at startup.
    """
    if model is None:
        return {"Error": "Model not loaded"}
    # Build a (1, T, C, H, W) batch of preprocessed frames.
    batch = torch.stack([transform(frame) for frame in frames])
    batch = batch.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        scores = torch.softmax(logits, dim=1)[0].cpu().numpy()
    return {CLASS_NAMES[idx]: float(scores[idx]) for idx in range(NUM_CLASSES)}
def predict(files):
    """Classify an uploaded video or image(s) into crowd-behavior classes.

    Accepts one video file (frames sampled via ffmpeg), a batch of images
    (the first SEQUENCE_LENGTH are used), or a single image (repeated to
    fill the sequence).

    Args:
        files: None, a single filepath, or a list of filepaths from gr.File.

    Returns:
        dict mapping class names to probabilities, or {"Error": ...}.
    """
    if files is None:
        return {"Error": "Upload a file first!"}
    if isinstance(files, str):
        files = [files]
    # Video: sample SEQUENCE_LENGTH frames with ffmpeg.
    video_exts = (".mp4", ".mov", ".avi", ".mkv", ".webm")
    if len(files) == 1 and files[0].lower().endswith(video_exts):
        frames = extract_frames_ffmpeg(files[0])
        if frames is None:
            return {"Error": "FFmpeg could not extract frames!"}
        return do_predict(frames)
    # Multiple images. Fix: use SEQUENCE_LENGTH instead of the hard-coded
    # 16, and report unreadable files instead of crashing (the original
    # only guarded the single-image branch).
    if len(files) >= SEQUENCE_LENGTH:
        try:
            imgs = [Image.open(f).convert("RGB") for f in files[:SEQUENCE_LENGTH]]
        except OSError:
            return {"Error": "Invalid image"}
        return do_predict(imgs)
    # Single (or too few) images: repeat the first image to fill the sequence.
    try:
        img = Image.open(files[0]).convert("RGB")
    except OSError:
        # Narrowed from a bare except: PIL raises OSError subclasses
        # (UnidentifiedImageError, FileNotFoundError) for bad inputs.
        return {"Error": "Invalid image"}
    return do_predict([img] * SEQUENCE_LENGTH)
# ------------------ CSS (insert via HTML) ------------------
# ------------------ CSS (insert via HTML) ------------------
# Injected as a raw <style> tag through gr.HTML: dark page background plus
# a reusable "glass" (glassmorphism) card class used by the React header.
css_html = """
<style>
body, .gradio-container {
background: #0b0f12 !important;
color: white !important;
}
.glass {
backdrop-filter: blur(12px) saturate(180%);
background: rgba(255,255,255,0.06);
border-radius: 16px;
padding: 20px;
border: 1px solid rgba(255,255,255,0.08);
box-shadow: 0 4px 40px rgba(0,0,0,0.4);
}
</style>
"""
# ------------------ REACT FRONTEND (subtitle color updated) ------------------
# ------------------ REACT FRONTEND (subtitle color updated) ------------------
# Header card plus a client-side React preview: it hooks the #media_input
# file picker, reads up to 16 files as data URLs, and cycles them every
# 300 ms as a pseudo-video preview. Purely cosmetic — predictions come from
# the Python predict() handler, not this script.
# NOTE(review): the 16 here presumably mirrors SEQUENCE_LENGTH; it is a
# literal inside the JS string, so keep them in sync manually.
react_html = """
<div class="glass">
<h1 style="margin:0;font-size:28px;color:red;">Crowd Behavior Analyzer</h1>
<!-- subtitle color changed to a light tone (#E6EEF3) for readability -->
<p style="color:#E6EEF3; opacity:0.95; margin-top:6px; margin-bottom:10px;">
Dark • Glassmorphism • React Autoplay Preview
</p>
<div id="react-root"></div>
</div>
<script crossorigin src="https://unpkg.com/react@18/umd/react.production.min.js"></script>
<script crossorigin src="https://unpkg.com/react-dom@18/umd/react-dom.production.min.js"></script>
<script>
const e = React.createElement;
function App(){
const [frames,setFrames] = React.useState([]);
const [i,setI] = React.useState(0);
React.useEffect(()=>{
const inp = document.getElementById("media_input");
if(!inp) return;
inp.addEventListener("change",() =>{
const files = inp.files;
if(!files || !files.length) return;
const picks = [...files].slice(0,16).map(f => {
return new Promise(res=>{
const r=new FileReader();
r.onload=()=>res(r.result);
r.readAsDataURL(f);
});
});
Promise.all(picks).then(data=>{
while(data.length < 16) data.push(data[0]);
setFrames(data);
setI(0);
});
});
},[]);
React.useEffect(()=>{
if(!frames.length) return;
const t=setInterval(()=>setI(x=>(x+1)%frames.length),300);
return ()=>clearInterval(t);
},[frames]);
return e("div",{},
frames.length
? e("img",{src:frames[i],style:{width:"100%",borderRadius:"12px"}})
: e("p",{style:{opacity:0.5}},"Preview will appear here after upload.")
);
}
ReactDOM.createRoot(document.getElementById("react-root")).render(e(App));
</script>
"""
# ------------------ UI ------------------
# ------------------ UI ------------------
# Gradio layout: injected CSS and React preview header, a multi-file
# uploader (elem_id must stay "media_input" — the React script hooks it),
# and an Analyze button wired to predict().
with gr.Blocks() as demo:
    gr.HTML(css_html)
    gr.HTML(react_html)
    file_input = gr.File(
        label="Upload video or multiple images",
        file_count="multiple",
        type="filepath",
        elem_id="media_input"
    )
    btn = gr.Button("Analyze", variant="primary")
    # num_top_classes matches NUM_CLASSES so all four scores are shown.
    output = gr.Label(num_top_classes=4)
    btn.click(predict, file_input, output)
demo.launch()