# gptj-test / utils.py
# Last commit e3b39f3 by varun500: "Update utils.py"
import boto3
import json
# Name of the SageMaker endpoint that hosts the model. Endpoint names must be
# unique within an AWS Region in your AWS account.
endpoint_name = "sm-endpoint-gpt-j-6b"

# Low-level Amazon SageMaker Runtime client pinned to us-east-1. It is built
# from a boto3 Session created with no explicit profile or credentials, and
# is what generate_text uses to call invoke_endpoint.
session = boto3.Session()
sagemaker_runtime = session.client(
    "sagemaker-runtime",
    region_name="us-east-1",
)
def generate_text(prompt, params):
    """Generate text for *prompt* by invoking the SageMaker endpoint.

    Serializes ``{"inputs": prompt, "parameters": params}`` to JSON and sends
    it to the endpoint named by the module-level ``endpoint_name`` via the
    module-level ``sagemaker_runtime`` client.

    Args:
        prompt: Value sent as the ``"inputs"`` field of the request.
        params: Value sent unchanged as the ``"parameters"`` field; presumably
            a dict of generation options for the model container (confirm
            against the deployed model).

    Returns:
        The ``"generated_text"`` entry of the first element of the decoded
        JSON response.

    Raises:
        Any error raised by ``invoke_endpoint`` propagates unchanged. A
        response that is not a non-empty list of objects carrying
        ``"generated_text"`` surfaces as the corresponding
        IndexError/KeyError/TypeError.
    """
    request_body = json.dumps({"inputs": prompt, "parameters": params})
    reply = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=request_body,
    )
    # The response body is a stream: read all bytes, decode, then parse JSON.
    raw_bytes = reply["Body"].read()
    decoded = json.loads(raw_bytes.decode())
    first_candidate = decoded[0]
    return first_candidate["generated_text"]