ccclemenfff committed on
Commit
6b701b8
·
1 Parent(s): 31f8561
Files changed (3) hide show
  1. .idea/embodied_explainer.iml +1 -1
  2. .idea/misc.xml +7 -0
  3. handler.py +14 -9
.idea/embodied_explainer.iml CHANGED
@@ -2,7 +2,7 @@
2
  <module type="PYTHON_MODULE" version="4">
3
  <component name="NewModuleRootManager">
4
  <content url="file://$MODULE_DIR$" />
5
- <orderEntry type="inheritedJdk" />
6
  <orderEntry type="sourceFolder" forTests="false" />
7
  </component>
8
  <component name="PyDocumentationSettings">
 
2
  <module type="PYTHON_MODULE" version="4">
3
  <component name="NewModuleRootManager">
4
  <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="demo" jdkType="Python SDK" />
6
  <orderEntry type="sourceFolder" forTests="false" />
7
  </component>
8
  <component name="PyDocumentationSettings">
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="demo" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="demo" project-jdk-type="Python SDK" />
7
+ </project>
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import torch
 
3
  from PIL import Image
4
  from io import BytesIO
5
  from typing import Dict, Any
@@ -34,24 +35,28 @@ class EndpointHandler:
34
  )
35
 
36
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
37
- # Hugging Face 会调用这个函数,data 是原始输入
38
  inputs = self.preprocess(data)
39
  prediction = self.inference(inputs)
40
  return self.postprocess(prediction)
41
 
42
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
43
  prompt = request["inputs"]
44
- image = request.get("image", None)
45
- video = request.get("video", None)
46
 
47
- if image:
48
- pixel_values = self._load_image(image).unsqueeze(0).to(self.device)
 
 
 
 
49
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
50
- elif video:
51
- pixel_values = self._load_video(video).unsqueeze(0).to(self.device)
 
 
 
52
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
53
- else:
54
- pixel_values = None
55
 
56
  return {
57
  "prompt": prompt,
 
1
  import os
2
  import torch
3
+ import base64
4
  from PIL import Image
5
  from io import BytesIO
6
  from typing import Dict, Any
 
35
  )
36
 
37
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
 
38
  inputs = self.preprocess(data)
39
  prediction = self.inference(inputs)
40
  return self.postprocess(prediction)
41
 
42
  def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
43
  prompt = request["inputs"]
44
+ image_b64 = request.get("image", None)
45
+ video_b64 = request.get("video", None)
46
 
47
+ pixel_values = None
48
+
49
+ if image_b64:
50
+ # Key change: base64-decode the image payload
51
+ image_bytes = base64.b64decode(image_b64)
52
+ pixel_values = self._load_image(image_bytes).unsqueeze(0).to(self.device)
53
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
54
+
55
+ elif video_b64:
56
+ # Key change: base64-decode the video payload
57
+ video_bytes = base64.b64decode(video_b64)
58
+ pixel_values = self._load_video(video_bytes).unsqueeze(0).to(self.device)
59
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
 
 
60
 
61
  return {
62
  "prompt": prompt,