developer0hye commited on
Commit
6802457
Β·
verified Β·
1 Parent(s): 1c13a5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -100,10 +100,10 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
100
  ])
101
  return frame_indices
102
 
103
- def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=8):
104
  """
105
  InternVL μ˜ˆμ‹œ μ½”λ“œ μ°Έκ³ : μ—¬λŸ¬ ν”„λ ˆμž„μ„ μΆ”μΆœν•˜μ—¬ dynamic_preprocess 적용.
106
- μ—¬κΈ°μ„œλŠ” 기본적으둜 num_segments=8둜 μ„€μ •.
107
  """
108
  vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
109
  max_frame = len(vr) - 1
@@ -130,15 +130,17 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=8
130
  # =============================================================================
131
  # InternVL λͺ¨λΈ λ‘œλ”©
132
  # =============================================================================
133
- MODEL_ID = "OpenGVLab/InternVL2_5-8B"
134
 
135
  model = AutoModel.from_pretrained(
136
  MODEL_ID,
137
  torch_dtype=torch.bfloat16,
 
138
  low_cpu_mem_usage=True,
139
  use_flash_attn=True,
140
- trust_remote_code=True
141
- ).eval().cuda()
 
142
 
143
  tokenizer = AutoTokenizer.from_pretrained(
144
  MODEL_ID,
@@ -147,7 +149,7 @@ tokenizer = AutoTokenizer.from_pretrained(
147
  )
148
 
149
  # Gradio 상단에 ν‘œμ‹œν•  μ„€λͺ… 문ꡬ
150
- DESCRIPTION = "[InternVL2_5-8B Demo](https://github.com/OpenGVLab/InternVL) - Using the InternVL2_5-8B"
151
 
152
  image_extensions = Image.registered_extensions()
153
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
@@ -234,27 +236,29 @@ def internvl_inference(media_input, text_input=None):
234
  pixel_values = load_image(media_path, max_num=12)
235
  pixel_values = pixel_values.to(torch.bfloat16).cuda() # (N, 3, H, W)
236
  # InternVL λŒ€ν™”
237
- question = f"<image>\n{text_input}" if text_input else "<image>\n"
238
  generation_config = dict(max_new_tokens=1024, do_sample=True)
239
 
240
  response = model.chat(
241
  tokenizer,
242
  pixel_values,
243
  question,
244
- generation_config
 
 
245
  )
246
  return response
247
 
248
  elif media_type == "video":
249
- # μ˜μƒ: μ˜ˆμ‹œλ‘œ 첫 8ν”„λ ˆμž„μ— λŒ€ν•΄ 처리
250
  pixel_values, num_patches_list = load_video(
251
  media_path,
252
- num_segments=8,
253
  max_num=1
254
  )
255
  pixel_values = pixel_values.to(torch.bfloat16).cuda()
256
  question_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
257
- question = question_prefix + (text_input if text_input else "")
258
  generation_config = dict(max_new_tokens=1024, do_sample=True)
259
 
260
  # μ˜μƒμ—μ„œλ„ λ™μΌν•œ chat() ν•¨μˆ˜ μ‚¬μš©
@@ -263,7 +267,9 @@ def internvl_inference(media_input, text_input=None):
263
  pixel_values,
264
  question,
265
  generation_config,
266
- num_patches_list=num_patches_list
 
 
267
  )
268
  return response
269
 
 
100
  ])
101
  return frame_indices
102
 
103
+ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
104
  """
105
  InternVL μ˜ˆμ‹œ μ½”λ“œ μ°Έκ³ : μ—¬λŸ¬ ν”„λ ˆμž„μ„ μΆ”μΆœν•˜μ—¬ dynamic_preprocess 적용.
106
+ μ—¬κΈ°μ„œλŠ” 기본적으둜 num_segments=32둜 μ„€μ •.
107
  """
108
  vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
109
  max_frame = len(vr) - 1
 
130
  # =============================================================================
131
  # InternVL λͺ¨λΈ λ‘œλ”©
132
  # =============================================================================
133
+ MODEL_ID = "OpenGVLab/InternVL3_5-8B"
134
 
135
  model = AutoModel.from_pretrained(
136
  MODEL_ID,
137
  torch_dtype=torch.bfloat16,
138
+ load_in_8bit=False,
139
  low_cpu_mem_usage=True,
140
  use_flash_attn=True,
141
+ trust_remote_code=True,
142
+ device_map="auto"
143
+ ).eval()
144
 
145
  tokenizer = AutoTokenizer.from_pretrained(
146
  MODEL_ID,
 
149
  )
150
 
151
  # Gradio 상단에 ν‘œμ‹œν•  μ„€λͺ… 문ꡬ
152
+ DESCRIPTION = "[InternVL3.5-8B Demo](https://github.com/OpenGVLab/InternVL) - Using the InternVL3.5-8B"
153
 
154
  image_extensions = Image.registered_extensions()
155
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
 
236
  pixel_values = load_image(media_path, max_num=12)
237
  pixel_values = pixel_values.to(torch.bfloat16).cuda() # (N, 3, H, W)
238
  # InternVL λŒ€ν™”
239
+ question = f"<image>\n{text_input}" if text_input else "<image>\nPlease describe the image."
240
  generation_config = dict(max_new_tokens=1024, do_sample=True)
241
 
242
  response = model.chat(
243
  tokenizer,
244
  pixel_values,
245
  question,
246
+ generation_config,
247
+ history=None,
248
+ return_history=False
249
  )
250
  return response
251
 
252
  elif media_type == "video":
253
+ # μ˜μƒ: μ˜ˆμ‹œλ‘œ 32ν”„λ ˆμž„μ— λŒ€ν•΄ 처리
254
  pixel_values, num_patches_list = load_video(
255
  media_path,
256
+ num_segments=32,
257
  max_num=1
258
  )
259
  pixel_values = pixel_values.to(torch.bfloat16).cuda()
260
  question_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
261
+ question = question_prefix + (text_input if text_input else "Describe this video in detail.")
262
  generation_config = dict(max_new_tokens=1024, do_sample=True)
263
 
264
  # μ˜μƒμ—μ„œλ„ λ™μΌν•œ chat() ν•¨μˆ˜ μ‚¬μš©
 
267
  pixel_values,
268
  question,
269
  generation_config,
270
+ num_patches_list=num_patches_list,
271
+ history=None,
272
+ return_history=False
273
  )
274
  return response
275