darisdzakwanhoesien commited on
Commit
6b6719d
·
verified ·
1 Parent(s): d617813

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -15
app.py CHANGED
@@ -2,9 +2,16 @@ import gradio as gr
2
  import json
3
  import torch
4
  import decord
5
- from transformers import AutoProcessor, AutoModelForVideoClassification, AutoTokenizer, AutoModel
 
 
 
 
 
6
 
7
- # --- 1. VideoMAE ---
 
 
8
  vm_model_id = "MCG-NJU/videomae-base-finetuned-kinetics"
9
  vm_processor = AutoProcessor.from_pretrained(vm_model_id)
10
  vm_model = AutoModelForVideoClassification.from_pretrained(vm_model_id)
@@ -25,32 +32,78 @@ def run_videomae(video):
25
  except Exception as e:
26
  return {"model": "VideoMAE", "error": str(e)}
27
 
28
- # --- 2. InternVideo2.5-Chat-8B ---
 
 
29
  iv_model_id = "OpenGVLab/InternVideo2_5_Chat_8B"
30
- iv_tokenizer = AutoTokenizer.from_pretrained(iv_model_id, trust_remote_code=True)
31
- iv_model = AutoModel.from_pretrained(iv_model_id, trust_remote_code=True).half().cuda().eval()
 
 
 
 
 
 
 
 
 
32
 
33
  def run_internvideo(video, prompt):
 
 
34
  try:
35
- # Simplified: they provide a .chat() API in trust_remote_code
36
- response, _ = iv_model.chat(iv_tokenizer, video_path=video, user_prompt=prompt, history=None)
37
- return {"model": "InternVideo2.5-Chat-8B", "output": response}
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
  return {"model": "InternVideo2.5-Chat-8B", "error": str(e)}
40
 
41
- # --- 3. LLaVA-Video-Llama-3.1-8B ---
 
 
42
  llava_model_id = "weizhiwang/LLaVA-Video-Llama-3.1-8B"
43
- llava_tokenizer = AutoTokenizer.from_pretrained(llava_model_id, trust_remote_code=True)
44
- llava_model = AutoModel.from_pretrained(llava_model_id, trust_remote_code=True).half().cuda().eval()
 
 
 
 
 
 
 
 
 
45
 
46
  def run_llava(video, prompt):
 
 
47
  try:
48
- response, _ = llava_model.chat(llava_tokenizer, video_path=video, user_prompt=prompt, history=None)
49
- return {"model": "LLaVA-Video-Llama-3.1-8B", "output": response}
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
  return {"model": "LLaVA-Video-Llama-3.1-8B", "error": str(e)}
52
 
53
- # --- Unified ---
 
 
54
  def analyze_all(video, prompt):
55
  results = []
56
  results.append(run_videomae(video))
@@ -58,10 +111,15 @@ def analyze_all(video, prompt):
58
  results.append(run_llava(video, prompt))
59
  return json.dumps(results, indent=2)
60
 
 
 
 
61
  demo = gr.Interface(
62
  fn=analyze_all,
63
  inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Prompt")],
64
- outputs="json"
 
 
65
  )
66
 
67
  if __name__ == "__main__":
 
2
  import json
3
  import torch
4
  import decord
5
+ from transformers import (
6
+ AutoProcessor,
7
+ AutoModelForVideoClassification,
8
+ AutoTokenizer,
9
+ AutoModel
10
+ )
11
 
12
+ # ------------------------------------------------------------
13
+ # 1. VideoMAE (simple classification)
14
+ # ------------------------------------------------------------
15
  vm_model_id = "MCG-NJU/videomae-base-finetuned-kinetics"
16
  vm_processor = AutoProcessor.from_pretrained(vm_model_id)
17
  vm_model = AutoModelForVideoClassification.from_pretrained(vm_model_id)
 
32
  except Exception as e:
33
  return {"model": "VideoMAE", "error": str(e)}
34
 
35
+ # ------------------------------------------------------------
36
+ # 2. InternVideo2.5-Chat-8B
37
+ # ------------------------------------------------------------
38
  iv_model_id = "OpenGVLab/InternVideo2_5_Chat_8B"
39
+ try:
40
+ iv_tokenizer = AutoTokenizer.from_pretrained(iv_model_id, trust_remote_code=True)
41
+ iv_model = AutoModel.from_pretrained(
42
+ iv_model_id,
43
+ trust_remote_code=True,
44
+ revision="main" # pin revision for stability
45
+ ).to(torch.bfloat16).cuda().eval()
46
+ except Exception as e:
47
+ iv_model = None
48
+ iv_tokenizer = None
49
+ iv_load_error = str(e)
50
 
51
  def run_internvideo(video, prompt):
52
+ if iv_model is None:
53
+ return {"model": "InternVideo2.5-Chat-8B", "error": iv_load_error}
54
  try:
55
+ # TODO: Replace with proper frame extraction & preprocessing from repo
56
+ question = "Describe this video."
57
+ output, _ = iv_model.chat(
58
+ iv_tokenizer,
59
+ None, # placeholder: pixel_values
60
+ question,
61
+ generation_config={"max_new_tokens": 256},
62
+ num_patches_list=[1],
63
+ history=None,
64
+ return_history=True
65
+ )
66
+ return {"model": "InternVideo2.5-Chat-8B", "output": output}
67
  except Exception as e:
68
  return {"model": "InternVideo2.5-Chat-8B", "error": str(e)}
69
 
70
+ # ------------------------------------------------------------
71
+ # 3. LLaVA-Video-Llama-3.1-8B
72
+ # ------------------------------------------------------------
73
  llava_model_id = "weizhiwang/LLaVA-Video-Llama-3.1-8B"
74
+ try:
75
+ lv_tokenizer = AutoTokenizer.from_pretrained(llava_model_id, trust_remote_code=True)
76
+ lv_model = AutoModel.from_pretrained(
77
+ llava_model_id,
78
+ trust_remote_code=True,
79
+ revision="main"
80
+ ).to(torch.bfloat16).cuda().eval()
81
+ except Exception as e:
82
+ lv_model = None
83
+ lv_tokenizer = None
84
+ lv_load_error = str(e)
85
 
86
  def run_llava(video, prompt):
87
+ if lv_model is None:
88
+ return {"model": "LLaVA-Video-Llama-3.1-8B", "error": lv_load_error}
89
  try:
90
+ # TODO: Replace with proper preprocessing from repo
91
+ output, _ = lv_model.chat(
92
+ lv_tokenizer,
93
+ None, # placeholder: pixel_values
94
+ prompt,
95
+ generation_config={"max_new_tokens": 256},
96
+ num_patches_list=[1],
97
+ history=None,
98
+ return_history=True
99
+ )
100
+ return {"model": "LLaVA-Video-Llama-3.1-8B", "output": output}
101
  except Exception as e:
102
  return {"model": "LLaVA-Video-Llama-3.1-8B", "error": str(e)}
103
 
104
+ # ------------------------------------------------------------
105
+ # Unified function
106
+ # ------------------------------------------------------------
107
  def analyze_all(video, prompt):
108
  results = []
109
  results.append(run_videomae(video))
 
111
  results.append(run_llava(video, prompt))
112
  return json.dumps(results, indent=2)
113
 
114
+ # ------------------------------------------------------------
115
+ # Gradio UI
116
+ # ------------------------------------------------------------
117
  demo = gr.Interface(
118
  fn=analyze_all,
119
  inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Prompt")],
120
+ outputs="json",
121
+ title="Multi-Model Video Analysis",
122
+ description="Runs the same video + prompt through VideoMAE, InternVideo2.5-Chat-8B, and LLaVA-Video-Llama-3.1-8B."
123
  )
124
 
125
  if __name__ == "__main__":