eddywu commited on
Commit
43f8af8
·
verified ·
1 Parent(s): ef2a23b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -12
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import gradio as gr
3
  import spaces
4
  import torch
@@ -9,6 +9,91 @@ from qwen_vl_utils import process_vision_info
9
  # --- 配置區 ---
10
  REPO_ID = "Memories-ai/security_model"
11
  TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # 載入模型(用私有 token),自動上 GPU
14
  @lru_cache(maxsize=1)
@@ -89,7 +174,13 @@ def caption_video(video_path: str) -> str:
89
  if not video_path:
90
  return "No video provided."
91
 
 
92
  model, processor = _load_model_and_processor()
 
 
 
 
 
93
  messages = [
94
  {
95
  "role": "user",
@@ -101,6 +192,7 @@ def caption_video(video_path: str) -> str:
101
  ]
102
 
103
  # 建構聊天模板與多模態輸入
 
104
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
  image_inputs, video_inputs, video_kwargs = process_vision_info(
106
  messages, return_video_kwargs=True
@@ -118,19 +210,48 @@ def caption_video(video_path: str) -> str:
118
  # 上 GPU(若可)
119
  if torch.cuda.is_available():
120
  inputs = inputs.to("cuda")
 
 
 
121
 
 
 
 
 
 
 
 
122
  with torch.inference_mode():
123
- generated_ids = model.generate(**inputs, max_new_tokens=768)
124
- generated_ids_trimmed = [
125
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
126
- ]
127
- output_text = processor.batch_decode(
128
- generated_ids_trimmed,
129
- skip_special_tokens=True,
130
- clean_up_tokenization_spaces=False
131
- )
132
-
133
- return output_text[0] if output_text else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # Gradio 介面
136
  demo = gr.Interface(
 
1
+ import os, json, time, subprocess, tempfile, shutil
2
  import gradio as gr
3
  import spaces
4
  import torch
 
9
  # --- 配置區 ---
10
  REPO_ID = "Memories-ai/security_model"
11
  TOKEN = os.environ.get("HF_TOKEN")
12
+ MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "160")) # 原 768 太高,先收斂
13
+ FORCE_FPS = int(os.environ.get("FORCE_FPS", "6")) # 影片抽幀 6fps 足夠 caption
14
+ TARGET_MAX_W = int(os.environ.get("TARGET_MAX_W", "1280")) # 寬度上限 1280 (<=720p)
15
+ DEBUG_TIMINGS = os.environ.get("DEBUG_TIMINGS", "0") == "1" # 1 時把分段時間附在輸出
16
+
17
+ # 速度小優化(Ampere 以後有效)
18
+ torch.backends.cuda.matmul.allow_tf32 = True
19
+ try:
20
+ torch.set_float32_matmul_precision("high")
21
+ except Exception:
22
+ pass
23
+
24
+ # ---------- 實用工具:ffprobe & 可能轉碼 ----------
25
+ def _run_quiet(cmd: list[str]):
26
+ subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
27
+
28
+ def ffprobe_meta(path: str):
29
+ try:
30
+ out = subprocess.check_output([
31
+ "ffprobe","-v","error","-select_streams","v:0",
32
+ "-show_entries","stream=codec_name,width,height,avg_frame_rate",
33
+ "-of","json", path
34
+ ])
35
+ data = json.loads(out.decode("utf-8"))
36
+ st = data["streams"][0] if data.get("streams") else {}
37
+ fps = 0.0
38
+ afr = st.get("avg_frame_rate","0/0")
39
+ if isinstance(afr,str) and "/" in afr:
40
+ num, den = afr.split("/")
41
+ fps = float(num)/float(den) if float(den) != 0 else 0.0
42
+ return {
43
+ "codec": st.get("codec_name"),
44
+ "w": int(st.get("width") or 0),
45
+ "h": int(st.get("height") or 0),
46
+ "fps": fps
47
+ }
48
+ except Exception:
49
+ return {"codec": None, "w": 0, "h": 0, "fps": 0.0}
50
+
51
+ def maybe_transcode(input_path: str):
52
+ """
53
+ 碰到 HEVC/H.265 或解析度太大時,快速轉成 H.264 + yuv420p + 目標寬度 + 限制 FPS
54
+ 轉完回傳 (path, used_temp=True/False, reason)
55
+ """
56
+ meta = ffprobe_meta(input_path)
57
+ codec, w, h, fps = meta["codec"], meta["w"], meta["h"], meta["fps"]
58
+
59
+ need_codec_fix = codec in ("hevc","h265")
60
+ need_resize = (w and w > TARGET_MAX_W)
61
+ need_fps = (fps and fps > FORCE_FPS + 0.5)
62
+
63
+ if not (need_codec_fix or need_resize or need_fps):
64
+ return input_path, False, {"meta": meta, "transcoded": False}
65
+
66
+ tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
67
+ out_path = tmp.name; tmp.close()
68
+
69
+ # scale 只在寬度超標時啟動,保留比例;fps 超標則限速
70
+ vf_parts = []
71
+ if need_resize:
72
+ vf_parts.append(f"scale='min({TARGET_MAX_W},iw)':-2")
73
+ if need_fps:
74
+ vf_parts.append(f"fps={FORCE_FPS}")
75
+ vf = ",".join(vf_parts) if vf_parts else "scale=trunc(iw/2)*2:trunc(ih/2)*2"
76
+
77
+ cmd = [
78
+ "ffmpeg","-y","-i", input_path,
79
+ "-vsync","vfr",
80
+ "-c:v","libx264","-preset","veryfast","-crf","23",
81
+ "-pix_fmt","yuv420p",
82
+ "-vf", vf,
83
+ "-c:a","aac","-b:a","128k",
84
+ "-movflags","+faststart",
85
+ out_path
86
+ ]
87
+ _run_quiet(cmd)
88
+ return out_path, True, {"meta": meta, "transcoded": True, "vf": vf}
89
+
90
+ # ---------- 分段計時 ----------
91
+ class Timer:
92
+ def __init__(self): self.t0=time.perf_counter(); self.spans=[]
93
+ def mark(self, name, dur): self.spans.append((name, round(dur,3)))
94
+ def result(self):
95
+ total = round(time.perf_counter()-self.t0, 3)
96
+ return {"total_s": total, **{k:v for k,v in self.spans}}
97
 
98
  # 載入模型(用私有 token),自動上 GPU
99
  @lru_cache(maxsize=1)
 
174
  if not video_path:
175
  return "No video provided."
176
 
177
+ T = Timer()
178
  model, processor = _load_model_and_processor()
179
+ # 1) 可能轉碼 / 降維 / 限 FPS
180
+ t = time.perf_counter()
181
+ safe_path, used_temp, tr_info = maybe_transcode(video_path)
182
+ T.mark("maybe_transcode_s", time.perf_counter()-t)
183
+
184
  messages = [
185
  {
186
  "role": "user",
 
192
  ]
193
 
194
  # 建構聊天模板與多模態輸入
195
+ t = time.perf_counter()
196
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
197
  image_inputs, video_inputs, video_kwargs = process_vision_info(
198
  messages, return_video_kwargs=True
 
210
  # 上 GPU(若可)
211
  if torch.cuda.is_available():
212
  inputs = inputs.to("cuda")
213
+ torch.cuda.synchronize()
214
+ T.mark("preprocess_s", time.perf_counter()-t)
215
+
216
 
217
+ gen_kwargs = dict(
218
+ max_new_tokens=MAX_NEW_TOKENS,
219
+ do_sample=False, # caption 任務較適合確定性解碼,速度更快
220
+ temperature=0.0,
221
+ top_p=1.0
222
+ )
223
+ t = time.perf_counter()
224
  with torch.inference_mode():
225
+ generated_ids = model.generate(**inputs, **gen_kwargs)
226
+ if torch.cuda.is_available(): torch.cuda.synchronize()
227
+ T.mark("generate_s", time.perf_counter()-t)
228
+
229
+
230
+ # 5) 後處理
231
+ t = time.perf_counter()
232
+ generated_ids_trimmed = [
233
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
234
+ ]
235
+ output_text = processor.batch_decode(
236
+ generated_ids_trimmed,
237
+ skip_special_tokens=True,
238
+ clean_up_tokenization_spaces=False
239
+ )
240
+ T.mark("postprocess_s", time.perf_counter()-t)
241
+
242
+ # 6) 清理暫存檔
243
+ if used_temp:
244
+ try: os.remove(safe_path)
245
+ except Exception: pass
246
+
247
+ # 打印詳細 timing 到日誌(HF Spaces Logs 可見)
248
+ print({"timings": T.result(), "transcode": tr_info})
249
+
250
+ caption = (output_text[0] if output_text else "").strip()
251
+ if DEBUG_TIMINGS:
252
+ rt = T.result()
253
+ caption += f"\n\n[timings] total={rt['total_s']}s, transcode={rt.get('maybe_transcode_s','-')}s, preprocess={rt.get('preprocess_s','-')}s, generate={rt.get('generate_s','-')}s"
254
+ return caption
255
 
256
  # Gradio 介面
257
  demo = gr.Interface(