ccclemenfff committed on
Commit
2d341f4
·
1 Parent(s): 95022b5
Files changed (1) hide show
  1. handler.py +42 -30
handler.py CHANGED
@@ -1,4 +1,3 @@
1
- ### ✅ handler.py(优化版)
2
  import os
3
  import torch
4
  import base64
@@ -29,40 +28,41 @@ class EndpointHandler:
29
  bos_token_id=1,
30
  do_sample=False,
31
  # temperature=0.7,
32
- max_new_tokens=1024
33
  )
34
 
35
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
36
- try:
37
- inputs = self.preprocess(data)
38
- prediction = self.inference(inputs)
39
- return self.postprocess(prediction)
40
- except Exception as e:
41
- return {"output": f"❌ 推理失败: {str(e)}"}
42
 
43
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
44
  prompt = request["inputs"]
45
  image_b64 = request.get("image", None)
46
  video_b64 = request.get("video", None)
47
- num_segments = request.get("num_segments", 16)
48
 
49
  pixel_values = None
50
 
51
  if image_b64:
52
  image_bytes = base64.b64decode(image_b64)
53
- pixel_values = self._load_image(image_bytes).unsqueeze(0)
54
- pixel_values = pixel_values.half() if self.device == "cuda" else pixel_values
 
55
  pixel_values = pixel_values.to(self.device)
56
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
57
 
58
  elif video_b64:
59
  video_bytes = base64.b64decode(video_b64)
60
- pixel_values = self._load_video(video_bytes, num_segments)
61
- pixel_values = pixel_values.half() if self.device == "cuda" else pixel_values
 
62
  pixel_values = pixel_values.to(self.device)
63
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
64
 
65
- return {"prompt": prompt, "pixel_values": pixel_values}
 
 
 
66
 
67
  def inference(self, inputs: Dict[str, Any]) -> str:
68
  prompt = inputs["prompt"]
@@ -72,28 +72,38 @@ class EndpointHandler:
72
  model_inputs.pop("token_type_ids", None)
73
  model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
74
 
75
- print("📌 prompt token长度:", model_inputs["input_ids"].shape[1])
76
  if pixel_values is not None:
77
- print("🎞️ pixel shape:", pixel_values.shape)
78
-
79
- output = self.model.generate(
80
- **model_inputs,
81
- pixel_values=pixel_values,
82
- generation_config=self.gen_config,
83
- return_dict_in_generate=True,
84
- output_scores=True
85
- )
86
-
 
 
 
 
 
 
87
  generated_ids = output.sequences[0]
 
 
88
  clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
89
- return clean_text
 
90
 
 
91
  def postprocess(self, output: str) -> Dict[str, str]:
92
  return {"output": output.strip()}
93
 
94
  def _load_image(self, image_bytes: bytes) -> torch.Tensor:
95
  image = Image.open(BytesIO(image_bytes)).convert('RGB')
96
- size = int(224 / (224 / 256))
 
97
  transform = T.Compose([
98
  T.Resize(size, interpolation=InterpolationMode.BICUBIC),
99
  T.CenterCrop(224),
@@ -116,10 +126,12 @@ class EndpointHandler:
116
  T.Resize(224, interpolation=InterpolationMode.BICUBIC),
117
  T.CenterCrop(224),
118
  T.ToTensor(),
119
- T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
120
  ])
121
- processed = [transform(frame) for frame in frames]
122
- video_tensor = torch.stack(processed, dim=0).permute(1, 0, 2, 3).unsqueeze(0)
 
 
123
  return video_tensor
124
 
125
  def get_index(self, num_frames: int, num_segments: int):
 
 
1
  import os
2
  import torch
3
  import base64
 
28
  bos_token_id=1,
29
  do_sample=False,
30
  # temperature=0.7,
31
+ max_new_tokens=4096
32
  )
33
 
34
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
35
+ inputs = self.preprocess(data)
36
+ prediction = self.inference(inputs)
37
+ return self.postprocess(prediction)
 
 
 
38
 
39
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
40
  prompt = request["inputs"]
41
  image_b64 = request.get("image", None)
42
  video_b64 = request.get("video", None)
 
43
 
44
  pixel_values = None
45
 
46
  if image_b64:
47
  image_bytes = base64.b64decode(image_b64)
48
+ pixel_values = self._load_image(image_bytes).unsqueeze(0) # [1, 3, 224, 224]
49
+ if self.device == "cuda":
50
+ pixel_values = pixel_values.half()
51
  pixel_values = pixel_values.to(self.device)
52
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
53
 
54
  elif video_b64:
55
  video_bytes = base64.b64decode(video_b64)
56
+ pixel_values = self._load_video(video_bytes)
57
+ if self.device == "cuda":
58
+ pixel_values = pixel_values.half()
59
  pixel_values = pixel_values.to(self.device)
60
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
61
 
62
+ return {
63
+ "prompt": prompt,
64
+ "pixel_values": pixel_values
65
+ }
66
 
67
  def inference(self, inputs: Dict[str, Any]) -> str:
68
  prompt = inputs["prompt"]
 
72
  model_inputs.pop("token_type_ids", None)
73
  model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
74
 
 
75
  if pixel_values is not None:
76
+ output = self.model.generate(
77
+ **model_inputs,
78
+ pixel_values=pixel_values,
79
+ max_new_tokens=self.gen_config.max_new_tokens, # 👈 显式传入
80
+ generation_config=self.gen_config,
81
+ return_dict_in_generate=True,
82
+ output_scores=True
83
+ )
84
+ else:
85
+ output = self.model.language_model.generate(
86
+ **model_inputs,
87
+ generation_config=self.gen_config,
88
+ return_dict_in_generate=True,
89
+ output_scores=True
90
+ )
91
+ # 🧠 打印 debug 信息
92
  generated_ids = output.sequences[0]
93
+ print("📍生成的 token ids:", generated_ids.tolist())
94
+ raw_text = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
95
  clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
96
+ print("🧾 带特殊符号的输出:", raw_text)
97
+ print("✅ 去掉特殊符号的输出:", clean_text)
98
 
99
+ return clean_text # 返回干净版本
100
  def postprocess(self, output: str) -> Dict[str, str]:
101
  return {"output": output.strip()}
102
 
103
  def _load_image(self, image_bytes: bytes) -> torch.Tensor:
104
  image = Image.open(BytesIO(image_bytes)).convert('RGB')
105
+ crop_pct = 224 / 256
106
+ size = int(224 / crop_pct)
107
  transform = T.Compose([
108
  T.Resize(size, interpolation=InterpolationMode.BICUBIC),
109
  T.CenterCrop(224),
 
126
  T.Resize(224, interpolation=InterpolationMode.BICUBIC),
127
  T.CenterCrop(224),
128
  T.ToTensor(),
129
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
130
  ])
131
+ processed = [transform(frame) for frame in frames] # each: [3, 224, 224]
132
+ video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
133
+ video_tensor = video_tensor.permute(1, 0, 2, 3) # [3, T, 224, 224]
134
+ video_tensor = video_tensor.unsqueeze(0) # [1, 3, T, 224, 224] ✅
135
  return video_tensor
136
 
137
  def get_index(self, num_frames: int, num_segments: int):