suhpau committed
Commit 4105a85 · verified · 1 Parent(s): 356f4f2

Update app.py

Files changed (1)
  1. app.py +71 -39
app.py CHANGED
@@ -1,73 +1,105 @@
-import gradio as gr
-import cv2
 import numpy as np
 import torch
 from transformers import AutoImageProcessor, VideoMAEForVideoClassification
-import tempfile
-import os
 
-MODEL_DIR = "models/hotcold_videomae"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 processor = None
 model = None
 
-def load_model():
     global processor, model
     processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
     model = VideoMAEForVideoClassification.from_pretrained(MODEL_DIR)
     model.to(device)
     model.eval()
 
-def sample_frames(video_path, num_frames=16, size=224):
-    cap = cv2.VideoCapture(video_path)
-    frames = []
-    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    idxs = np.linspace(0, max(total-1,0), num_frames).astype(int)
-
-    cur = 0
-    ret_frames = []
 
-    while True:
-        ok, frame = cap.read()
-        if not ok:
-            break
-        if cur in idxs:
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frame = cv2.resize(frame, (size,size))
-            ret_frames.append(frame)
-        cur += 1
 
-    cap.release()
 
-    while len(ret_frames) < num_frames:
-        ret_frames.append(ret_frames[-1])
 
-    return ret_frames[:num_frames]
 
-@torch.no_grad()
-def predict(video):
     if model is None:
-        load_model()
 
-    frames = sample_frames(video)
-    inputs = processor(frames, return_tensors="pt").to(device)
     outputs = model(**inputs)
-    probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
 
-    p_cold, p_hot = probs
-    if p_hot > p_cold:
-        return f"🔥 It's hot (probability {p_hot:.2f})"
     else:
-        return f"❄️ It's cold (probability {p_cold:.2f})"
 
 demo = gr.Interface(
     fn=predict,
-    inputs=gr.Video(label="Upload a video of a person's actions"),
     outputs="text",
     title="Hot / Cold Action Recognition",
-    description="Upload a video of a person's actions and it classifies whether it looks hot or cold."
 )
 
 if __name__ == "__main__":
 
+import os
 import numpy as np
 import torch
+import gradio as gr
+
 from transformers import AutoImageProcessor, VideoMAEForVideoClassification
+from decord import VideoReader, cpu
 
+MODEL_DIR = "models/hotcold_videomae"  # path to the model folder uploaded to the Space
+NUM_FRAMES = 16
+SIZE = 224
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 processor = None
 model = None
 
+def _extract_video_path(video_input):
+    """
+    Depending on the Gradio version, the Video input can arrive as
+    - str (filepath)
+    - tuple (filepath, subtitle/...)
+    - dict {"name": filepath, ...}
+    so handle every form.
+    """
+    if video_input is None:
+        return None
+
+    if isinstance(video_input, str):
+        return video_input
+
+    if isinstance(video_input, (tuple, list)) and len(video_input) > 0:
+        return video_input[0]
+
+    if isinstance(video_input, dict):
+        # usually {"name": ".../tmp/xxxx.mp4", ...}
+        return video_input.get("name") or video_input.get("path")
+
+    return None
+
+def _load_model():
     global processor, model
+
+    if not os.path.isdir(MODEL_DIR):
+        raise RuntimeError(
+            f"❌ Model folder not found: {MODEL_DIR}\n"
+            f"Check that 'models/hotcold_videomae/' exists in the Space's file list."
+        )
+
     processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
     model = VideoMAEForVideoClassification.from_pretrained(MODEL_DIR)
     model.to(device)
     model.eval()
 
+def _sample_frames_decord(video_path, num_frames=NUM_FRAMES, size=SIZE):
+    vr = VideoReader(video_path, ctx=cpu(0))
+    total = len(vr)
+    if total <= 0:
+        raise RuntimeError("❌ Could not read any video frames (the file may be empty).")
 
+    idxs = np.linspace(0, total - 1, num_frames).astype(int)
+    frames = vr.get_batch(idxs).asnumpy()  # (T, H, W, 3), RGB
 
+    # resize (simple version)
+    import cv2
+    out = []
+    for f in frames:
+        f = cv2.resize(f, (size, size), interpolation=cv2.INTER_LINEAR)
+        out.append(f)
+    return out
 
+@torch.no_grad()
+def predict(video_input):
+    global processor, model
 
+    video_path = _extract_video_path(video_input)
+    if not video_path:
+        return "❌ The video file was not passed correctly. Please upload it again."
 
     if model is None:
+        _load_model()
+
+    frames = _sample_frames_decord(video_path)
+    inputs = processor(frames, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
 
     outputs = model(**inputs)
+    probs = torch.softmax(outputs.logits, dim=-1)[0].detach().cpu().numpy()
 
+    p_cold, p_hot = float(probs[0]), float(probs[1])
+    if p_hot >= p_cold:
+        return f"🔥 It's hot (hot={p_hot:.2f}, cold={p_cold:.2f})"
     else:
+        return f"❄️ It's cold (cold={p_cold:.2f}, hot={p_hot:.2f})"
 
 demo = gr.Interface(
     fn=predict,
+    inputs=gr.Video(label="Upload an action video"),
     outputs="text",
     title="Hot / Cold Action Recognition",
+    description="Upload a video of a person's actions and it classifies whether it looks hot or cold.",
+    cache_examples=False,  # tip: avoids mp4 caching issues on Spaces
 )
 
 if __name__ == "__main__":
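
A minimal local smoke test for the new decord-based pipeline, as a sketch rather than part of app.py itself. It assumes the file above is importable as app, that a short test clip exists at ./sample.mp4, that models/hotcold_videomae/ is available locally, and that decord, opencv-python, torch, transformers, and gradio are installed (on Spaces these belong in requirements.txt). The module name app and the sample.mp4 path are illustrative assumptions, not part of the commit.

    # Smoke test for the updated pipeline; `app` and ./sample.mp4 are assumed names.
    import app

    frames = app._sample_frames_decord("sample.mp4")
    assert len(frames) == app.NUM_FRAMES                 # 16 uniformly sampled frames
    assert frames[0].shape == (app.SIZE, app.SIZE, 3)    # 224x224 RGB

    # predict() also accepts a plain filepath, since _extract_video_path() handles str input.
    print(app.predict("sample.mp4"))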