Update handler.py
Browse files- handler.py +4 -4
handler.py
CHANGED
|
@@ -59,15 +59,15 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 59 |
# return response
|
| 60 |
|
| 61 |
class EndpointHandler:
|
| 62 |
-
def __init__(self):
|
| 63 |
# Load processor and model
|
| 64 |
self.PROCESSOR = AutoProcessor.from_pretrained(
|
| 65 |
-
|
| 66 |
trust_remote_code=True,
|
| 67 |
# token=API_TOKEN,
|
| 68 |
)
|
| 69 |
self.MODEL = AutoModelForCausalLM.from_pretrained(
|
| 70 |
-
|
| 71 |
# token=API_TOKEN,
|
| 72 |
trust_remote_code=True,
|
| 73 |
torch_dtype=torch.bfloat16,
|
|
@@ -99,7 +99,7 @@ class EndpointHandler:
|
|
| 99 |
# inputs = preprocess(model_inputs)
|
| 100 |
generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
|
| 101 |
generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 102 |
-
return {"
|
| 103 |
# return {"text":prediction[0]}
|
| 104 |
|
| 105 |
# @classmethod
|
|
|
|
| 59 |
# return response
|
| 60 |
|
| 61 |
class EndpointHandler:
|
| 62 |
+
def __init__(self,model_path:str):
|
| 63 |
# Load processor and model
|
| 64 |
self.PROCESSOR = AutoProcessor.from_pretrained(
|
| 65 |
+
model_path,
|
| 66 |
trust_remote_code=True,
|
| 67 |
# token=API_TOKEN,
|
| 68 |
)
|
| 69 |
self.MODEL = AutoModelForCausalLM.from_pretrained(
|
| 70 |
+
model_path,
|
| 71 |
# token=API_TOKEN,
|
| 72 |
trust_remote_code=True,
|
| 73 |
torch_dtype=torch.bfloat16,
|
|
|
|
| 99 |
# inputs = preprocess(model_inputs)
|
| 100 |
generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
|
| 101 |
generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 102 |
+
return {"text": generated_text}
|
| 103 |
# return {"text":prediction[0]}
|
| 104 |
|
| 105 |
# @classmethod
|