Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import tempfile | |
| import base64 | |
| import numpy as np | |
# Number of frames the model consumes per prediction.
SEQUENCE_LENGTH = 16
# Human-readable label for each model output index (index i -> CLASS_NAMES[i]).
CLASS_NAMES = ["aggressive", "idle", "panic", "normal"]
NUM_CLASSES = len(CLASS_NAMES)
# Checkpoint expected next to this script.
MODEL_PATH = "best_model.pth"
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ------------------ MODEL ------------------
class CNNLSTM(nn.Module):
    """Per-frame CNN encoder followed by a single-layer LSTM over the frames.

    The class scores are computed from the LSTM output at the last time step.
    Attribute names (``cnn``, ``lstm``, ``fc``) and the layer order inside
    ``cnn`` determine the ``state_dict`` keys, so they must stay as they are
    for saved checkpoints to load.
    """

    def __init__(self, num_classes):
        super().__init__()
        conv_layers = [
            nn.Conv2d(3, 32, 3, padding=1),   # cnn.0
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),  # cnn.3
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.cnn = nn.Sequential(*conv_layers)
        # Two 2x poolings shrink a 64x64 frame to a 64-channel 16x16 map,
        # which is flattened into one LSTM input vector per frame.
        self.lstm = nn.LSTM(64 * 16 * 16, 128, batch_first=True)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: tensor unpacked as (batch, time, channels, height, width);
               channels must be 3 and each frame must pool down to 16x16
               (i.e. 64x64 input) to match the LSTM input size.

        Returns:
            Unnormalised class scores of shape (batch, num_classes).
        """
        batch, steps, channels, height, width = x.size()
        # Fold time into the batch so the CNN sees independent frames.
        frame_maps = self.cnn(x.view(batch * steps, channels, height, width))
        per_step = frame_maps.view(batch, steps, -1)
        hidden_seq, _ = self.lstm(per_step)
        final_hidden = hidden_seq[:, -1, :]
        return self.fc(final_hidden)
# ------------------ LOAD MODEL ------------------
def load_model():
    """Build a CNNLSTM and load its weights from MODEL_PATH.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
        Any error raised by ``torch.load`` / ``load_state_dict`` (e.g. a
        checkpoint whose keys or shapes do not match CNNLSTM) propagates.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError("Upload best_model.pth to the Space!")
    model = CNNLSTM(NUM_CLASSES).to(device)
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here. On torch >= 1.13, passing weights_only=True would
    # restrict loading to tensors; confirm the deployed torch version first.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Load once at import time. On failure the app still starts with model = None,
# and callers must check for that. Catch Exception (not a bare except) so
# KeyboardInterrupt/SystemExit still propagate, and report why loading failed
# instead of swallowing the cause.
try:
    model = load_model()
except Exception as exc:
    print(f"[load_model] model not loaded: {exc!r}", flush=True)
    model = None
# ------------------ FRAME EXTRACTION (FFmpeg) ------------------
def extract_frames_ffmpeg(video_path):
    """Sample SEQUENCE_LENGTH RGB frames from a video using the ffmpeg CLI.

    ffmpeg decodes the video at 1 frame per second, scaled to 320x180, into a
    temporary directory that is deleted before this function returns.

    If at least SEQUENCE_LENGTH frames are produced, they are sampled evenly
    across the clip. Otherwise the available frames are repeated cyclically
    until the sequence is full.

    Args:
        video_path: path to the video file.

    Returns:
        A list of exactly SEQUENCE_LENGTH ``PIL.Image`` objects in RGB mode,
        or ``None`` if ffmpeg is not installed or produced no frames.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Zero-padded to 6 digits so lexicographic sorting keeps temporal
        # order. With %03d, frame_1000.jpg would sort before frame_101.jpg.
        out_pattern = os.path.join(tmp_dir, "frame_%06d.jpg")
        cmd = [
            "ffmpeg",
            # Global options go first, per the ffmpeg command-line convention.
            "-hide_banner",
            "-loglevel", "error",
            "-i", video_path,
            "-vf", "fps=1,scale=320:180",
            out_pattern,
        ]
        try:
            # A non-zero exit is tolerated: any frames written are still used,
            # and the empty-directory check below reports total failure.
            # DEVNULL stops ffmpeg from reading the server's stdin.
            subprocess.run(cmd, stdin=subprocess.DEVNULL)
        except FileNotFoundError:
            # ffmpeg binary is not installed / not on PATH.
            return None

        jpgs = sorted(
            os.path.join(tmp_dir, name)
            for name in os.listdir(tmp_dir)
            if name.endswith(".jpg")
        )
        if not jpgs:
            return None

        if len(jpgs) >= SEQUENCE_LENGTH:
            idxs = np.linspace(0, len(jpgs) - 1, SEQUENCE_LENGTH).astype(int)
            jpgs = [jpgs[i] for i in idxs]
        else:
            jpgs = (jpgs * SEQUENCE_LENGTH)[:SEQUENCE_LENGTH]

        # convert() loads the pixel data into a new in-memory image, so the
        # frames stay valid after the temporary directory is removed.
        frames = []
        for path in jpgs:
            with Image.open(path) as img:
                frames.append(img.convert("RGB"))
        return frames
# ------------------ PREDICTION ------------------
# Per-frame preprocessing: resize each PIL frame to 64x64 (the size the
# CNNLSTM's LSTM input dimension assumes) and convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
def do_predict(frames):
    """Run the loaded model on a sequence of PIL frames.

    Args:
        frames: sequence of PIL images, one per time step.

    Returns:
        ``{class_name: probability}`` for every class, or
        ``{"Error": "Model not loaded"}`` when the global model is missing.
    """
    if model is None:
        return {"Error": "Model not loaded"}

    clip = torch.stack([transform(frame) for frame in frames])
    batch = clip.unsqueeze(0).to(device)  # add a leading batch dimension of 1

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()

    scores = {}
    for idx in range(NUM_CLASSES):
        scores[CLASS_NAMES[idx]] = float(probabilities[idx])
    return scores
_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")


def _load_rgb(path):
    """Open an image file and return a fully loaded RGB copy (handle closed)."""
    with Image.open(path) as img:
        return img.convert("RGB")


def predict(files):
    """Gradio click handler: classify an uploaded video or set of images.

    Args:
        files: a single file path, or a list of uploaded file paths.

    Returns:
        ``{class_name: probability}`` from ``do_predict``, or
        ``{"Error": message}`` when nothing usable was uploaded.
    """
    if files is None:
        return {"Error": "Upload a file first!"}
    if isinstance(files, str):
        files = [files]
    if not files:
        return {"Error": "Upload a file first!"}

    # Video: exactly one file with a known video extension.
    if len(files) == 1 and files[0].lower().endswith(_VIDEO_EXTENSIONS):
        frames = extract_frames_ffmpeg(files[0])
        if frames is None:
            return {"Error": "FFmpeg could not extract frames!"}
        return do_predict(frames)

    # Images: take up to SEQUENCE_LENGTH of them in upload order. A shorter set
    # (including a single image) is repeated cyclically to fill the sequence,
    # the same padding extract_frames_ffmpeg applies to short videos.
    # Previously, 2..SEQUENCE_LENGTH-1 images silently used only the first one.
    try:
        imgs = [_load_rgb(path) for path in files[:SEQUENCE_LENGTH]]
    except (OSError, ValueError, Image.DecompressionBombError):
        # Only image decoding is guarded, so model errors are not
        # mislabelled as "Invalid image".
        return {"Error": "Invalid image"}

    frames = (imgs * SEQUENCE_LENGTH)[:SEQUENCE_LENGTH]
    return do_predict(frames)
# ------------------ CSS (insert via HTML) ------------------
# Page-wide dark theme plus the ".glass" card style applied to the header
# <div class="glass"> in react_html. Injected as a <style> tag through a
# gr.HTML component rather than via gr.Blocks(css=...).
css_html = """
<style>
body, .gradio-container {
background: #0b0f12 !important;
color: white !important;
}
.glass {
backdrop-filter: blur(12px) saturate(180%);
background: rgba(255,255,255,0.06);
border-radius: 16px;
padding: 20px;
border: 1px solid rgba(255,255,255,0.08);
box-shadow: 0 4px 40px rgba(0,0,0,0.4);
}
</style>
"""
# ------------------ REACT FRONTEND ------------------
# Header card plus a client-side looping preview of the uploaded images.
# React 18 is loaded from the unpkg CDN. The preview listens for file
# selections made in the upload widget whose elem_id is "media_input";
# MAX_FRAMES mirrors SEQUENCE_LENGTH (16).
#
# NOTE(review): browsers do not execute <script> tags inserted via innerHTML.
# If the installed Gradio version renders gr.HTML that way, neither the CDN
# scripts nor the inline script below will run, and this JS would have to be
# supplied through the Blocks-level head/js hook instead. Confirm against the
# deployed Gradio version.
react_html = """
<div class="glass">
<h1 style="margin:0;font-size:28px;color:red;">Crowd Behavior Analyzer</h1>
<!-- subtitle color changed to a light tone (#E6EEF3) for readability -->
<p style="color:#E6EEF3; opacity:0.95; margin-top:6px; margin-bottom:10px;">
Dark • Glassmorphism • React Autoplay Preview
</p>
<div id="react-root"></div>
</div>
<script crossorigin src="https://unpkg.com/react@18/umd/react.production.min.js"></script>
<script crossorigin src="https://unpkg.com/react-dom@18/umd/react-dom.production.min.js"></script>
<script>
const e = React.createElement;
const MAX_FRAMES = 16;
function App(){
  const [frames,setFrames] = React.useState([]);
  const [i,setI] = React.useState(0);
  React.useEffect(()=>{
    // Gradio applies elem_id to a wrapper element, not necessarily to the
    // <input type="file"> itself, and that wrapper may mount after this
    // component. So listen on the document and accept only file inputs
    // inside #media_input. The capture phase reads the FileList before
    // handlers on the input itself run.
    const onChange = (ev)=>{
      const inp = ev.target;
      if(!inp || inp.type !== "file" || !inp.closest || !inp.closest("#media_input")) return;
      // Only images can be shown in an <img>; videos are skipped.
      const files = [...(inp.files || [])].filter(f => f.type && f.type.startsWith("image/"));
      if(!files.length){ setFrames([]); return; }
      const picks = files.slice(0, MAX_FRAMES).map(f => new Promise(res=>{
        const r = new FileReader();
        r.onload = ()=>res(r.result);
        r.onerror = ()=>res(null);
        r.readAsDataURL(f);
      }));
      Promise.all(picks).then(results=>{
        const data = results.filter(Boolean);
        if(!data.length){ setFrames([]); return; }
        while(data.length < MAX_FRAMES) data.push(data[0]);
        setFrames(data);
        setI(0);
      });
    };
    document.addEventListener("change", onChange, true);
    return ()=>document.removeEventListener("change", onChange, true);
  },[]);
  React.useEffect(()=>{
    if(!frames.length) return;
    const t = setInterval(()=>setI(x=>(x+1)%frames.length),300);
    return ()=>clearInterval(t);
  },[frames]);
  return e("div",{},
    frames.length
      ? e("img",{src:frames[i],style:{width:"100%",borderRadius:"12px"}})
      : e("p",{style:{opacity:0.5}},"Preview will appear here after upload.")
  );
}
ReactDOM.createRoot(document.getElementById("react-root")).render(e(App));
</script>
"""
# ------------------ UI ------------------
with gr.Blocks() as demo:
    # Raw HTML first: the dark/glass stylesheet, then the header + preview.
    for markup in (css_html, react_html):
        gr.HTML(markup)

    file_input = gr.File(
        label="Upload video or multiple images",
        file_count="multiple",
        type="filepath",
        elem_id="media_input",  # referenced by the preview script in react_html
    )
    btn = gr.Button("Analyze", variant="primary")
    output = gr.Label(num_top_classes=4)

    btn.click(fn=predict, inputs=file_input, outputs=output)

demo.launch()