ccclemenfff committed on
Commit
4832cce
·
1 Parent(s): eecb9b2
Files changed (1) hide show
  1. handler.py +31 -42
handler.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import base64
@@ -28,41 +29,40 @@ class EndpointHandler:
28
  bos_token_id=1,
29
  do_sample=False,
30
  temperature=0.7,
31
- max_new_tokens=10240
32
  )
33
 
34
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
35
- inputs = self.preprocess(data)
36
- prediction = self.inference(inputs)
37
- return self.postprocess(prediction)
 
 
 
38
 
39
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
40
  prompt = request["inputs"]
41
  image_b64 = request.get("image", None)
42
  video_b64 = request.get("video", None)
 
43
 
44
  pixel_values = None
45
 
46
  if image_b64:
47
  image_bytes = base64.b64decode(image_b64)
48
- pixel_values = self._load_image(image_bytes).unsqueeze(0) # [1, 3, 224, 224]
49
- if self.device == "cuda":
50
- pixel_values = pixel_values.half()
51
  pixel_values = pixel_values.to(self.device)
52
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
53
 
54
  elif video_b64:
55
  video_bytes = base64.b64decode(video_b64)
56
- pixel_values = self._load_video(video_bytes)
57
- if self.device == "cuda":
58
- pixel_values = pixel_values.half()
59
  pixel_values = pixel_values.to(self.device)
60
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
61
 
62
- return {
63
- "prompt": prompt,
64
- "pixel_values": pixel_values
65
- }
66
 
67
  def inference(self, inputs: Dict[str, Any]) -> str:
68
  prompt = inputs["prompt"]
@@ -72,37 +72,28 @@ class EndpointHandler:
72
  model_inputs.pop("token_type_ids", None)
73
  model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
74
 
 
75
  if pixel_values is not None:
76
- output = self.model.generate(
77
- **model_inputs,
78
- pixel_values=pixel_values,
79
- generation_config=self.gen_config,
80
- return_dict_in_generate=True,
81
- output_scores=True
82
- )
83
- else:
84
- output = self.model.language_model.generate(
85
- **model_inputs,
86
- generation_config=self.gen_config,
87
- return_dict_in_generate=True,
88
- output_scores=True
89
- )
90
- # 🧠 打印 debug 信息
91
  generated_ids = output.sequences[0]
92
- print("📍生成的 token ids:", generated_ids.tolist())
93
- raw_text = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
94
  clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
95
- print("🧾 带特殊符号的输出:", raw_text)
96
- print("✅ 去掉特殊符号的输出:", clean_text)
97
 
98
- return clean_text # 返回干净版本
99
  def postprocess(self, output: str) -> Dict[str, str]:
100
  return {"output": output.strip()}
101
 
102
  def _load_image(self, image_bytes: bytes) -> torch.Tensor:
103
  image = Image.open(BytesIO(image_bytes)).convert('RGB')
104
- crop_pct = 224 / 256
105
- size = int(224 / crop_pct)
106
  transform = T.Compose([
107
  T.Resize(size, interpolation=InterpolationMode.BICUBIC),
108
  T.CenterCrop(224),
@@ -111,7 +102,7 @@ class EndpointHandler:
111
  ])
112
  return transform(image)
113
 
114
- def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
115
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
116
  tmpfile.write(video_bytes)
117
  video_path = tmpfile.name
@@ -125,12 +116,10 @@ class EndpointHandler:
125
  T.Resize(224, interpolation=InterpolationMode.BICUBIC),
126
  T.CenterCrop(224),
127
  T.ToTensor(),
128
- T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
129
  ])
130
- processed = [transform(frame) for frame in frames] # each: [3, 224, 224]
131
- video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
132
- video_tensor = video_tensor.permute(1, 0, 2, 3) # [3, T, 224, 224]
133
- video_tensor = video_tensor.unsqueeze(0) # [1, 3, T, 224, 224] ✅
134
  return video_tensor
135
 
136
  def get_index(self, num_frames: int, num_segments: int):
 
1
+ ### ✅ handler.py(优化版)
2
  import os
3
  import torch
4
  import base64
 
29
  bos_token_id=1,
30
  do_sample=False,
31
  temperature=0.7,
32
+ max_new_tokens=1024
33
  )
34
 
35
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
36
+ try:
37
+ inputs = self.preprocess(data)
38
+ prediction = self.inference(inputs)
39
+ return self.postprocess(prediction)
40
+ except Exception as e:
41
+ return {"output": f"❌ 推理失败: {str(e)}"}
42
 
43
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
44
  prompt = request["inputs"]
45
  image_b64 = request.get("image", None)
46
  video_b64 = request.get("video", None)
47
+ num_segments = request.get("num_segments", 16)
48
 
49
  pixel_values = None
50
 
51
  if image_b64:
52
  image_bytes = base64.b64decode(image_b64)
53
+ pixel_values = self._load_image(image_bytes).unsqueeze(0)
54
+ pixel_values = pixel_values.half() if self.device == "cuda" else pixel_values
 
55
  pixel_values = pixel_values.to(self.device)
56
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
57
 
58
  elif video_b64:
59
  video_bytes = base64.b64decode(video_b64)
60
+ pixel_values = self._load_video(video_bytes, num_segments)
61
+ pixel_values = pixel_values.half() if self.device == "cuda" else pixel_values
 
62
  pixel_values = pixel_values.to(self.device)
63
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
64
 
65
+ return {"prompt": prompt, "pixel_values": pixel_values}
 
 
 
66
 
67
  def inference(self, inputs: Dict[str, Any]) -> str:
68
  prompt = inputs["prompt"]
 
72
  model_inputs.pop("token_type_ids", None)
73
  model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
74
 
75
+ print("📌 prompt token长度:", model_inputs["input_ids"].shape[1])
76
  if pixel_values is not None:
77
+ print("🎞️ pixel shape:", pixel_values.shape)
78
+
79
+ output = self.model.generate(
80
+ **model_inputs,
81
+ pixel_values=pixel_values,
82
+ generation_config=self.gen_config,
83
+ return_dict_in_generate=True,
84
+ output_scores=True
85
+ )
86
+
 
 
 
 
 
87
  generated_ids = output.sequences[0]
 
 
88
  clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
89
+ return clean_text
 
90
 
 
91
  def postprocess(self, output: str) -> Dict[str, str]:
92
  return {"output": output.strip()}
93
 
94
  def _load_image(self, image_bytes: bytes) -> torch.Tensor:
95
  image = Image.open(BytesIO(image_bytes)).convert('RGB')
96
+ size = int(224 / (224 / 256))
 
97
  transform = T.Compose([
98
  T.Resize(size, interpolation=InterpolationMode.BICUBIC),
99
  T.CenterCrop(224),
 
102
  ])
103
  return transform(image)
104
 
105
+ def _load_video(self, video_bytes: bytes, num_segments=16) -> torch.Tensor:
106
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
107
  tmpfile.write(video_bytes)
108
  video_path = tmpfile.name
 
116
  T.Resize(224, interpolation=InterpolationMode.BICUBIC),
117
  T.CenterCrop(224),
118
  T.ToTensor(),
119
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
120
  ])
121
+ processed = [transform(frame) for frame in frames]
122
+ video_tensor = torch.stack(processed, dim=0).permute(1, 0, 2, 3).unsqueeze(0)
 
 
123
  return video_tensor
124
 
125
  def get_index(self, num_frames: int, num_segments: int):