vibrantturtle committed on
Commit
4e5f09a
·
verified ·
1 Parent(s): 1a4a55b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import numpy as np
5
+ import torchvision.transforms as T
6
+ from decord import VideoReader, cpu
7
+ from PIL import Image
8
+ from torchvision.transforms.functional import InterpolationMode
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ MODEL_ID = "OpenGVLab/InternVideo2_5_Chat_8B"
12
+
13
+ # Load once at startup (Space will cache weights after first run)
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
15
+ model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).half().cuda().to(torch.bfloat16)
16
+ model.eval()
17
+
18
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
19
+ IMAGENET_STD = (0.229, 0.224, 0.225)
20
+
21
+ def build_transform(input_size=448):
22
+ return T.Compose([
23
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
24
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
25
+ T.ToTensor(),
26
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
27
+ ])
28
+
29
+ def sample_frames(video_path, num_segments=16, input_size=448):
30
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
31
+ max_frame = len(vr) - 1
32
+ if max_frame <= 0:
33
+ idxs = [0]
34
+ else:
35
+ idxs = np.linspace(0, max_frame, num_segments).astype(int).tolist()
36
+
37
+ transform = build_transform(input_size)
38
+ pixel_values_list = []
39
+ num_patches_list = []
40
+
41
+ # Simple: one tile per frame (keeps memory lower)
42
+ for i in idxs:
43
+ img = Image.fromarray(vr[i].asnumpy()).convert("RGB")
44
+ pv = transform(img).unsqueeze(0) # [1,3,H,W]
45
+ pixel_values_list.append(pv)
46
+ num_patches_list.append(1)
47
+
48
+ pixel_values = torch.cat(pixel_values_list, dim=0) # [T,3,H,W]
49
+ return pixel_values, num_patches_list
50
+
51
+ @spaces.GPU
52
+ @torch.no_grad()
53
+ def analyze(video, prompt, num_segments, max_new_tokens):
54
+ if video is None:
55
+ return "Upload a video first."
56
+
57
+ # gr.Video returns a dict-like object in some gradio versions;
58
+ # safest: handle both string path and dict
59
+ if isinstance(video, dict) and "path" in video:
60
+ video_path = video["path"]
61
+ else:
62
+ video_path = video
63
+
64
+ pixel_values, num_patches_list = sample_frames(
65
+ video_path,
66
+ num_segments=int(num_segments),
67
+ input_size=448
68
+ )
69
+ pixel_values = pixel_values.to(torch.bfloat16).to(model.device)
70
+
71
+ video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
72
+ question = video_prefix + (prompt or "Describe this video in detail.")
73
+
74
+ generation_config = dict(
75
+ do_sample=False,
76
+ temperature=0.0,
77
+ max_new_tokens=int(max_new_tokens),
78
+ top_p=0.1,
79
+ num_beams=1
80
+ )
81
+
82
+ out, _ = model.chat(
83
+ tokenizer,
84
+ pixel_values,
85
+ question,
86
+ generation_config,
87
+ num_patches_list=num_patches_list,
88
+ history=None,
89
+ return_history=True,
90
+ )
91
+ return out
92
+
93
+ demo = gr.Interface(
94
+ fn=analyze,
95
+ inputs=[
96
+ gr.Video(label="Upload video"),
97
+ gr.Textbox(label="Prompt", value="Describe what is happening. If someone is using a phone while driving, say so."),
98
+ gr.Slider(8, 64, value=16, step=8, label="Frames sampled (lower=faster/safer)"),
99
+ gr.Slider(64, 512, value=256, step=64, label="Max new tokens (lower=faster)"),
100
+ ],
101
+ outputs=gr.Textbox(label="Model output"),
102
+ title="InternVideo2.5 Chat 8B — Video Analysis Demo",
103
+ )
104
+
105
+ demo.launch()