ccclemenfff commited on
Commit
d946509
·
1 Parent(s): 6b701b8
Files changed (1) hide show
  1. handler.py +13 -4
handler.py CHANGED
@@ -47,15 +47,24 @@ class EndpointHandler:
47
  pixel_values = None
48
 
49
  if image_b64:
50
- # 关键改动:base64解码
51
  image_bytes = base64.b64decode(image_b64)
52
- pixel_values = self._load_image(image_bytes).unsqueeze(0).to(self.device)
 
 
 
 
 
 
53
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
54
 
55
  elif video_b64:
56
- # 关键改动:base64解码
57
  video_bytes = base64.b64decode(video_b64)
58
- pixel_values = self._load_video(video_bytes).unsqueeze(0).to(self.device)
 
 
 
 
 
59
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
60
 
61
  return {
 
47
  pixel_values = None
48
 
49
  if image_b64:
 
50
  image_bytes = base64.b64decode(image_b64)
51
+ pixel_values = self._load_image(image_bytes).unsqueeze(0)
52
+
53
+ # ⭐️ 如果模型是 float16,就把输入也变成 half
54
+ if self.device == "cuda":
55
+ pixel_values = pixel_values.half()
56
+
57
+ pixel_values = pixel_values.to(self.device)
58
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
59
 
60
  elif video_b64:
 
61
  video_bytes = base64.b64decode(video_b64)
62
+ pixel_values = self._load_video(video_bytes).unsqueeze(0)
63
+
64
+ if self.device == "cuda":
65
+ pixel_values = pixel_values.half()
66
+
67
+ pixel_values = pixel_values.to(self.device)
68
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
69
 
70
  return {