import json

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch  # NOTE(review): not used anywhere in this file

# Load vocab: token -> id, plus the reverse id -> token mapping for decoding.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
inv_vocab = {i: tok for tok, i in vocab.items()}

# NOTE(review): both special-token lookups use the empty string as the key, so
# they fall back to 0 / 1 unless vocab.json contains "". The intended names
# (e.g. "<pad>" / "<unk>") may have been stripped from this source -- confirm
# against vocab.json.
pad_idx = vocab.get("", 0)
unk_idx = vocab.get("", 1)

# Load ONNX model
session = ort.InferenceSession("chat_model.onnx")


def tokenize(text):
    """Map each single-space-separated word of *text* to its vocab id.

    Words missing from the vocab map to ``unk_idx``.
    """
    return [vocab.get(tok, unk_idx) for tok in text.split(" ")]


def pad_sequence(seq, max_len=20):
    """Right-pad *seq* with ``pad_idx`` up to *max_len*, truncating if longer."""
    seq = seq + [pad_idx] * (max_len - len(seq))
    return seq[:max_len]


def chat_onnx(input_text, max_len=20):
    """Greedily generate *max_len* tokens in reply to *input_text*.

    Args:
        input_text: User message; words are split on single spaces.
        max_len: Number of tokens to generate, and also the maximum number of
            context tokens fed to the model per step.

    Returns:
        The generated tokens decoded via ``inv_vocab`` and joined with spaces.
    """
    # Only real tokens go into the context. With right padding, position -1
    # would be a pad token and the first prediction would follow padding.
    context = tokenize(input_text or "")[:max_len]
    output_ids = []
    for _ in range(max_len):
        # The exported RNN is stateless, so the whole (windowed) context must be
        # fed every step. Feeding only the last generated token discards the prompt.
        window = context[-max_len:]
        input_tensor = np.array([window], dtype=np.int64)
        ort_outs = session.run(None, {"input": input_tensor})
        # Greedy decoding: pick the highest-scoring token at the last position.
        next_token = int(ort_outs[0][0, -1].argmax())
        output_ids.append(next_token)
        context.append(next_token)
    return " ".join(inv_vocab.get(i, "") for i in output_ids)


# Gradio interface
iface = gr.Interface(fn=chat_onnx, inputs="text", outputs="text")
iface.launch()