# chatbot-gru/src/serve_gru.py
# from huggingface_hub import hf_hub_download
import os
import re

import numpy as np
import requests
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
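# Inference script for a GRU seq2seq chatbot: downloads the trained model and
# tokenizer from the Hugging Face Hub, rebuilds the encoder and a single-step
# decoder from the saved layers, and answers messages via greedy decoding.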
MODEL_URL = "https://huggingface.co/robertkm23/chat_bot/resolve/main/chatbot_seq2seq.keras"
MODEL_PATH = "/tmp/chatbot_seq2seq.keras"
if not os.path.exists(MODEL_PATH):
    with requests.get(MODEL_URL, stream=True) as r:
        r.raise_for_status()
        with open(MODEL_PATH, "wb") as f:
            for chunk in r.iter_content(4 * 1024 * 1024):  # 4 MiB chunks
                f.write(chunk)
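# Streaming the response and writing it chunk by chunk keeps memory use
# bounded regardless of the size of the model file.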
# same for tokenizer.json…
TOK_URL = "https://huggingface.co/robertkm23/chat_bot/resolve/main/tokenizer.json"
TOK_PATH = "/tmp/tokenizer.json"
if not os.path.exists(TOK_PATH):
    with requests.get(TOK_URL, stream=True) as r:
        r.raise_for_status()
        with open(TOK_PATH, "wb") as f:
            for chunk in r.iter_content(1 * 1024 * 1024):  # 1 MiB chunks
                f.write(chunk)
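# Both artifacts are cached in /tmp, so they are only fetched on a cold start.
# Note that a partially written file from an interrupted download would pass
# the os.path.exists check and be reused as-is.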
MAXLEN = 22
START, END = "<start>", "<end>"
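# MAXLEN and the special tokens must mirror the preprocessing used when the
# model and tokenizer were trained; changing them here would break inference.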
# ── utilities --------------------------------------------------
def _norm(s: str) -> str:
    """Lowercase, keep alphanumerics and ?.! only, and space out punctuation."""
    s = re.sub(r"[^a-zA-Z0-9?!.]+", " ", s.lower())
    s = re.sub(r"([?.!])", r" \1 ", s)
    return re.sub(r"\s+", " ", s).strip()
def _pad(seq):
    """Pad (or truncate) token-id sequences to MAXLEN, padding at the end."""
    return tf.keras.preprocessing.sequence.pad_sequences(
        seq, maxlen=MAXLEN, padding="post"
    )
# ── load model and tokenizer -----------------------------------
print("‣ loading model and tokenizer…", end="", flush=True)
model = load_model(MODEL_PATH)
with open(TOK_PATH, encoding="utf-8") as f:
    tok = tokenizer_from_json(f.read())
# Pull the trained layers back out of the saved model so the encoder and the
# decoder can be driven separately at inference time.
emb_layer = model.get_layer("emb")
enc_gru = model.get_layer("enc_gru")
dec_gru = model.get_layer("dec_gru")
dense = model.get_layer("dense")
enc_model = tf.keras.Model(model.input[0], enc_gru.output[1])  # input → final encoder state
dec_cell = dec_gru.cell  # single-step GRU cell for token-by-token decoding
UNK_ID = tok.word_index["<unk>"]
START_ID = tok.word_index[START]
END_ID = tok.word_index[END]
print(" listo 🟢")
# ── single decoder step ----------------------------------------
def _step(tok_id, state):
    """Run the decoder for one token; return masked logits and the new state."""
    x = tf.constant([[tok_id]], dtype=tf.int32)  # (1, 1)
    x = emb_layer(x)                             # (1, 1, emb)
    x = tf.squeeze(x, axis=1)                    # (1, emb)
    h, _ = dec_cell(x, states=state)             # (1, units)
    logits = dense(h)[0].numpy()                 # (vocab,)
    logits[UNK_ID] = -1e9                        # never emit <unk>
    return logits, [h]
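# Keras RNN cells take and return their state as a list of tensors, which is
# why the state is wrapped in a single-element list here and in reply() below.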
# ── greedy inference -------------------------------------------
def reply(msg: str, max_len: int = MAXLEN) -> str:
    # normalize and encode the input message
    seq = _pad(tok.texts_to_sequences([f"{START} {_norm(msg)} {END}"]))
    h_enc = enc_model.predict(seq, verbose=0)  # (1, units)
    state = [tf.convert_to_tensor(h_enc)]      # [(1, units)]
    tok_id, out_ids = START_ID, []
    for _ in range(max_len):
        logits, state = _step(tok_id, state)
        # greedy: pick the most probable token
        tok_id = int(np.argmax(logits))
        # stopping conditions: end/start token, or three identical tokens in a row
        if tok_id in (END_ID, START_ID):
            break
        if len(out_ids) >= 2 and tok_id == out_ids[-1] == out_ids[-2]:
            break
        out_ids.append(tok_id)
    # rebuild the text
    return " ".join(tok.index_word[i] for i in out_ids) or "(no reply)"
# ── CLI demo (optional) ----------------------------------------
if __name__ == "__main__":
    while True:
        try:
            q = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):  # Ctrl-D / Ctrl-C exits cleanly
            break
        if not q:
            continue
        print("Bot:", reply(q))