| | import json |
| | import numpy as np |
| | import gradio as gr |
| | import onnxruntime as ort |
| |
|
| | |
# --- Vocabulary -------------------------------------------------------------
# vocab.json is assumed to map single-character tokens to integer ids,
# i.e. {token: id}.  NOTE(review): confirm this schema against the file.
# Fix: pin the encoding — the previous open() used the platform default,
# which can mis-decode a UTF-8 vocabulary (e.g. cp1252 on Windows).
with open("vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

token2id = vocab
# Inverse mapping (id -> token) used when decoding model output.
id2token = {v: k for k, v in vocab.items()}

# --- Model ------------------------------------------------------------------
# Load the ONNX model once at import time; `session` is reused by predict().
session = ort.InferenceSession("chat_model.onnx")
| |
|
def tokenize(text):
    """Map each character of *text* to its vocabulary id, 0 for unknowns."""
    ids = []
    for ch in text:
        ids.append(token2id.get(ch, 0))
    return ids
| |
|
def detokenize(ids):
    """Join the tokens for *ids* into one string, using "?" for unknown ids."""
    tokens = (id2token.get(token_id, "?") for token_id in ids)
    return "".join(tokens)
| |
|
def predict(text):
    """Run the ONNX chat model on *text* and return the decoded reply.

    Parameters
    ----------
    text : str
        Raw user input; each character is looked up in the vocabulary
        via tokenize() (unknown characters map to id 0).

    Returns
    -------
    str
        The detokenized model output, or "" for an empty prompt.
    """
    ids = tokenize(text)
    # Fix: an empty prompt would yield a (1, 0) int64 tensor, which most
    # sequence models reject at run time — short-circuit instead.
    if not ids:
        return ""

    # Batch of one sequence; ONNX expects int64 token ids.
    x = np.array([ids], dtype=np.int64)

    # Resolve the graph's I/O names dynamically rather than hard-coding
    # whatever names were baked in at export time.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # run() returns a list of requested outputs; take the first output's
    # first (and only) batch element.
    # NOTE(review): assumes the model emits token ids, not logits — confirm.
    output = session.run([output_name], {input_name: x})[0][0]
    return detokenize(output.tolist())
| |
|
| | |
# Build and serve the Gradio UI: one text box in, text out, wired to the
# ONNX-backed predict() above.  launch() blocks and starts the web server.
demo = gr.Interface(inputs="text", outputs="text", fn=predict)
demo.launch()