ccclemenfff committed on
Commit
b8aad58
·
1 Parent(s): 58e4849
Files changed (1) hide show
  1. handler.py +5 -50
handler.py CHANGED
@@ -1,52 +1,7 @@
1
- import os
2
- from typing import Dict, Any
3
- from PIL import Image
4
- from io import BytesIO
5
 
6
- from inference import Chat # 直接import你放的inference.py里Chat类
7
- from robohusky.conversation import get_conv_template
8
 
9
- class EndpointHandler:
10
- def __init__(self, path: str = "."):
11
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
- self.chat = Chat(
13
- model_path=path,
14
- device=self.device,
15
- num_gpus=1,
16
- max_new_tokens=1024,
17
- load_8bit=False
18
- )
19
- self.vision_feature = None
20
- self.modal_type = "text"
21
- self.conv = get_conv_template("husky").copy()
22
-
23
- def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
24
- query = inputs.get("inputs", "")
25
- self.conv = get_conv_template("husky").copy()
26
- self.vision_feature = None
27
- self.modal_type = "text"
28
-
29
- if "image" in inputs:
30
- image_bytes = inputs["image"]
31
- image = Image.open(BytesIO(image_bytes)).convert("RGB")
32
- image.save("temp.jpg")
33
- self.vision_feature = self.chat.get_image_embedding("temp.jpg")
34
- self.modal_type = "image"
35
-
36
- elif "video" in inputs:
37
- video_bytes = inputs["video"]
38
- with open("temp.mp4", "wb") as f:
39
- f.write(video_bytes)
40
- self.vision_feature = self.chat.get_video_embedding("temp.mp4")
41
- self.modal_type = "video"
42
-
43
- return {"query": query}
44
-
45
- def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
46
- processed = self.preprocess(inputs)
47
- query = processed["query"]
48
-
49
- conversations = self.chat.ask(text=query, conv=self.conv, modal_type=self.modal_type)
50
- outputs = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)
51
- self.conv.messages[-1][1] = outputs.strip()
52
- return {"output": outputs.strip()}
 
1
+ # handler.py
 
 
 
2
 
3
+ from inference import Chat
4
+ chat = Chat()
5
 
6
+ def inference_fn(inputs):
7
+ return chat.answer(inputs)