Update handler.py
Browse files
- handler.py +2 -2
handler.py
CHANGED
|
@@ -27,10 +27,10 @@ class EndpointHandler:
|
|
| 27 |
# pass inputs with all kwargs in data
|
| 28 |
if parameters is not None:
|
| 29 |
with torch.autocast("cuda"):
|
| 30 |
-
outputs = self.model.generate(**inputs, **parameters)
|
| 31 |
else:
|
| 32 |
with torch.autocast("cuda"):
|
| 33 |
-
outputs = self.model.generate(**inputs,)
|
| 34 |
|
| 35 |
# postprocess the prediction
|
| 36 |
prediction = outputs[0].cpu().numpy().tolist()
|
|
|
|
| 27 |
# pass inputs with all kwargs in data
|
| 28 |
if parameters is not None:
|
| 29 |
with torch.autocast("cuda"):
|
| 30 |
+
outputs = self.model.generate(**inputs, **parameters, do_sample=True, guidance_scale=3)
|
| 31 |
else:
|
| 32 |
with torch.autocast("cuda"):
|
| 33 |
+
outputs = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=450)
|
| 34 |
|
| 35 |
# postprocess the prediction
|
| 36 |
prediction = outputs[0].cpu().numpy().tolist()
|