# Spaces: Runtime error  (Hugging Face Space status banner captured by the page
# scrape — kept as a comment so the file remains valid Python)
# Standard library
import json
from traceback import format_exc

# Third-party
import flask_sock
import hivemind
import torch

# Project-local
import config
from app import sock, models
from utils import safe_decode

# Module-level logger, named after this file per hivemind convention.
logger = hivemind.get_logger(__file__)
def ws_api_generate(ws):
    """Serve one websocket inference session.

    Protocol (all messages are JSON):
      1. Client opens with {"type": "open_inference_session", "model": ...,
         "max_length": ...}; server replies {"ok": True}.
      2. Client repeatedly sends {"type": "generate", "inputs": ...,
         "stop_sequence": ..., "extra_stop_sequences": ..., ...}; the server
         streams back {"ok": True, "outputs": <text delta>, "stop": <bool>}
         messages until a stop sequence is produced.
    Any failure is reported once as {"ok": False, "traceback": ...} and the
    handler exits; a closed connection exits silently.
    """
    try:
        request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
        # NOTE(review): `assert` on untrusted client input is stripped under
        # `python -O`; an explicit `raise` would be safer. Kept as-is because
        # the broad `except Exception` below is the intended error path.
        assert request["type"] == "open_inference_session"
        model_name = request.get("model")
        if model_name is None:
            model_name = config.DEFAULT_MODEL_NAME
        logger.info(f"ws.generate.open(), model={repr(model_name)}, max_length={repr(request['max_length'])}")

        model, tokenizer = models[model_name]
        with model.inference_session(max_length=request["max_length"]) as session:
            ws.send(json.dumps({"ok": True}))
            while True:
                request = json.loads(ws.receive(timeout=config.STEP_TIMEOUT))
                assert request["type"] == "generate"
                inputs = request.get("inputs")
                logger.info(f"ws.generate.step(), inputs={repr(inputs)}")

                if inputs is not None:
                    inputs = tokenizer(inputs, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    n_input_tokens = inputs.shape[1]
                else:
                    n_input_tokens = 0

                stop_sequence = request.get("stop_sequence")
                extra_stop_sequences = request.get("extra_stop_sequences")
                if extra_stop_sequences is not None:
                    # The session can only be "continued" past a stop by
                    # re-injecting a single token, so the primary stop
                    # sequence must tokenize to exactly one token.
                    cont_token = tokenizer(stop_sequence, return_tensors="pt")["input_ids"].to(config.DEVICE)
                    assert cont_token.shape == (1, 1), \
                        "extra_stop_sequences require stop_sequence length to be exactly 1 token"

                all_outputs = ''
                delta_q = []  # token ids buffered while decoding is incomplete
                stop = False
                while not stop:
                    outputs = model.generate(
                        inputs=inputs,
                        do_sample=request.get("do_sample", False),
                        temperature=request.get("temperature", 1.0),
                        top_k=request.get("top_k"),
                        top_p=request.get("top_p"),
                        max_length=request.get("max_length"),
                        max_new_tokens=request.get("max_new_tokens"),
                        session=session,
                    )
                    delta = outputs[0, n_input_tokens:].tolist()
                    outputs = safe_decode(tokenizer, torch.Tensor(delta_q + delta))
                    inputs = None  # Inputs are passed only for the 1st token of the bot's response
                    n_input_tokens = 0
                    combined = all_outputs + outputs
                    stop = stop_sequence is None or combined.endswith(stop_sequence)
                    if extra_stop_sequences is not None:
                        for seq in extra_stop_sequences:
                            if combined.endswith(seq):
                                stop = True
                                # Allow the client to resume generation past
                                # this stop by restoring the continuation token.
                                session.last_token_id = cont_token
                    if not stop and outputs[-10:].find(u'\ufffd') > -1:
                        # If there's a replacement character, keep getting more tokens
                        # until we can decode properly
                        delta_q = delta_q + delta
                        logger.info(f"ws.generate.append_retry(), all_outputs={repr(combined)}")
                    else:
                        all_outputs = combined
                        delta_q = []
                        logger.info(f"ws.generate.step(), all_outputs={repr(all_outputs)}, stop={stop}")
                        ws.send(json.dumps({"ok": True, "outputs": outputs, "stop": stop}))
    except flask_sock.ConnectionClosed:
        pass  # client went away; nothing to report
    except Exception:
        logger.warning("ws.generate failed:", exc_info=True)
        # Fix: the socket may already be dead at this point; a failing send
        # here would raise out of the handler and mask the original error.
        try:
            ws.send(json.dumps({"ok": False, "traceback": format_exc()}))
        except Exception:
            pass
    finally:
        logger.info(f"ws.generate.close()")