Update handler.py
Browse files- handler.py +16 -4
handler.py
CHANGED
|
@@ -11,7 +11,19 @@ class EndpointHandler:
|
|
| 11 |
model, checkpoint=model_dir, device_map="auto"
|
| 12 |
) # 自动跨 GPU 切层
|
| 13 |
def __call__(self, data):
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
model, checkpoint=model_dir, device_map="auto"
|
| 12 |
) # 自动跨 GPU 切层
|
| 13 |
def __call__(self, data):
    """Generate a text completion for the prompt carried in ``data``.

    Args:
        data: Request payload. ``data["inputs"]`` is the prompt handed to
            the tokenizer. An optional ``data["parameters"]`` mapping may
            carry ``max_new_tokens`` to override the default budget of
            256 newly generated tokens.

    Returns:
        A dict with a single ``"generated_text"`` key holding the decoded
        first output sequence with special tokens stripped.

    Raises:
        KeyError: If ``data`` has no ``"inputs"`` entry.
        ValueError: If ``max_new_tokens`` is not a positive integer.
    """
    prompt = data["inputs"]

    # Let the request override the generation budget instead of always
    # using the hard-coded default; a missing/None mapping means defaults.
    params = data.get("parameters") or {}
    max_new_tokens = params.get("max_new_tokens", 256)
    if (
        isinstance(max_new_tokens, bool)
        or not isinstance(max_new_tokens, int)
        or max_new_tokens < 1
    ):
        raise ValueError(
            f"max_new_tokens must be a positive integer, got {max_new_tokens!r}"
        )

    # Put input_ids / attention_mask on GPU 0.
    # NOTE(review): presumably that is where the first layers of the
    # multi-GPU-dispatched model live -- confirm against the device map.
    inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda:0")

    out_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Only the first returned sequence is decoded.
    generated_text = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
    return {"generated_text": generated_text}
|
| 29 |
+
|