codys12 committed on
Commit
e2d63b2
·
1 Parent(s): c2f55a1

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +1 -1
handler.py CHANGED
@@ -10,6 +10,7 @@ LOGGER = logging.getLogger(__name__)
10
  logging.basicConfig(level=logging.INFO)
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
 
13
 
14
  class EndpointHandler():
15
  def __init__(self, path=""):
@@ -102,7 +103,6 @@ def generate(
102
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
103
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
104
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
105
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
106
  input_ids = input_ids.to(model.device)
107
 
108
  # Generate the response
 
10
  logging.basicConfig(level=logging.INFO)
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ MAX_INPUT_TOKEN_LENGTH = 16000
14
 
15
  class EndpointHandler():
16
  def __init__(self, path=""):
 
103
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
104
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
105
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
106
  input_ids = input_ids.to(model.device)
107
 
108
  # Generate the response