Commit
·
d19a591
1
Parent(s):
fa994ec
__call__ return values
Browse files- handler.py +18 -4
handler.py
CHANGED
|
@@ -13,9 +13,7 @@ class EndpointHandler:
|
|
| 13 |
self.model = AutoModel.from_pretrained(
|
| 14 |
model_dir,
|
| 15 |
torch_dtype=torch.bfloat16,
|
| 16 |
-
low_cpu_mem_usage=True,
|
| 17 |
trust_remote_code=True,
|
| 18 |
-
device_map="auto",
|
| 19 |
).eval()
|
| 20 |
|
| 21 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
@@ -23,9 +21,25 @@ class EndpointHandler:
|
|
| 23 |
)
|
| 24 |
|
| 25 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 26 |
-
logger.info(f"Received incoming request with {data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
if __name__ == "__main__":
|
| 30 |
handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
|
| 31 |
-
print(handler)
|
|
|
|
| 13 |
self.model = AutoModel.from_pretrained(
|
| 14 |
model_dir,
|
| 15 |
torch_dtype=torch.bfloat16,
|
|
|
|
| 16 |
trust_remote_code=True,
|
|
|
|
| 17 |
).eval()
|
| 18 |
|
| 19 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 24 |
+
logger.info(f"Received incoming request with {data}")
|
| 25 |
+
|
| 26 |
+
# Extract input text from the request data
|
| 27 |
+
input_text = data.get("inputs", "")
|
| 28 |
+
if not input_text:
|
| 29 |
+
logger.warning("No input text provided")
|
| 30 |
+
return [{"generated_text": ""}] # Return empty result but in valid format
|
| 31 |
+
|
| 32 |
+
# Tokenize the input
|
| 33 |
+
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
|
| 34 |
+
|
| 35 |
+
# Generate embeddings
|
| 36 |
+
with torch.no_grad():
|
| 37 |
+
outputs = self.model(**inputs)
|
| 38 |
+
|
| 39 |
+
# Process outputs - this depends on your specific model and requirements
|
| 40 |
+
# For now, we'll just return the input as the output to fix the array format issue
|
| 41 |
+
return [{"input_text": input_text, "generated_text": outputs}]
|
| 42 |
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
| 45 |
handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
|
|
|