robertkm23 committed on
Commit
e6434d7
·
verified ·
1 Parent(s): c772beb

Update serve_gru.py

Browse files
Files changed (1) hide show
  1. serve_gru.py +83 -90
serve_gru.py CHANGED
@@ -1,90 +1,83 @@
1
- # serve_gru.py ────────────────────────────────────────────────
2
- import re, numpy as np, tensorflow as tf
3
- from tensorflow.keras.models import load_model
4
- from tensorflow.keras.preprocessing.text import tokenizer_from_json
5
- from huggingface_hub import hf_hub_download
6
-
7
- # --- descarga desde tu Space/repo de HF ---
8
- MODEL_PATH = hf_hub_download(
9
- repo_id="robertkm23/chat_bot", filename="chatbot_seq2seq.keras",
10
- repo_type="model"
11
- )
12
- TOK_PATH = hf_hub_download(
13
- repo_id="robertkm23/chat_bot", filename="tokenizer.json",
14
- repo_type="model"
15
- )
16
- MAXLEN = 22
17
- START, END = "<start>", "<end>"
18
-
19
- # ── utilidades ------------------------------------------------
20
- def _norm(s: str) -> str:
21
- s = re.sub(r"[^a-zA-Z0-9?!.]+", " ", s.lower())
22
- s = re.sub(r"([?.!])", r" \1 ", s)
23
- return re.sub(r"\s+", " ", s).strip()
24
-
25
- def _pad(seq):
26
- return tf.keras.preprocessing.sequence.pad_sequences(
27
- seq, maxlen=MAXLEN, padding="post"
28
- )
29
-
30
- # ── carga modelo y tokenizer ----------------------------------
31
- print("β€£ cargando modelo y tokenizer…", end="", flush=True)
32
- model = load_model(MODEL_PATH)
33
- with open(TOK_PATH, encoding="utf-8") as f:
34
- tok = tokenizer_from_json(f.read())
35
-
36
- emb_layer = model.get_layer("emb")
37
- enc_gru = model.get_layer("enc_gru")
38
- dec_gru = model.get_layer("dec_gru")
39
- dense = model.get_layer("dense")
40
-
41
- enc_model = tf.keras.Model(model.input[0], enc_gru.output[1])
42
- dec_cell = dec_gru.cell
43
-
44
- UNK_ID = tok.word_index["<unk>"]
45
- START_ID = tok.word_index[START]
46
- END_ID = tok.word_index[END]
47
-
48
- print(" listo 🟒")
49
-
50
- # ── paso único del decoder ------------------------------------
51
- def _step(tok_id, state):
52
- # token → embedding
53
- x = tf.constant([[tok_id]], dtype=tf.int32) # (1,1)
54
- x = emb_layer(x) # (1,1,emb)
55
- x = tf.squeeze(x, axis=1) # (1,emb)
56
- h, _ = dec_cell(x, states=state) # (1,units)
57
- logits = dense(h)[0].numpy() # (vocab,)
58
- logits[UNK_ID] = -1e9 # nunca <unk>
59
- return logits, [h]
60
-
61
- # ── función de inferencia greedy -----------------------------
62
- def reply(msg: str, max_len: int = MAXLEN) -> str:
63
- # normaliza y codifica
64
- seq = _pad(tok.texts_to_sequences([f"{START} {_norm(msg)} {END}"]))
65
- h_enc = enc_model.predict(seq, verbose=0) # (1,units)
66
- state = [tf.convert_to_tensor(h_enc)] # [(1,units)]
67
-
68
- tok_id, out_ids = START_ID, []
69
- for _ in range(max_len):
70
- logits, state = _step(tok_id, state)
71
- # greedy: la más probable
72
- tok_id = int(np.argmax(logits))
73
-
74
- # condiciones de parada
75
- if tok_id in (END_ID, START_ID):
76
- break
77
- if len(out_ids) >= 2 and tok_id == out_ids[-1] == out_ids[-2]:
78
- break
79
-
80
- out_ids.append(tok_id)
81
-
82
- # reconstruye texto
83
- return " ".join(tok.index_word[i] for i in out_ids) or "(sin respuesta)"
84
-
85
- # ── demo CLI (opcional) ---------------------------------------
86
- if __name__ == "__main__":
87
- while True:
88
- q = input("TΓΊ: ").strip()
89
- if not q: continue
90
- print("Bot:", reply(q))
 
1
+ # serve_gru.py ────────────────────────────────────────────────
2
+ import re, numpy as np, tensorflow as tf
3
+ from tensorflow.keras.models import load_model
4
+ from tensorflow.keras.preprocessing.text import tokenizer_from_json
5
+ # from huggingface_hub import hf_hub_download
6
+
7
# --- local artifact paths and sequence settings ---
# Relative paths: the model and tokenizer files are opened from the
# current working directory when the module loads.
MODEL_PATH = "chatbot_seq2seq.keras"
TOK_PATH = "tokenizer.json"

# Target length used when padding encoder input; also the default cap on
# the number of decoding steps in reply().
MAXLEN = 22

# Marker tokens wrapped around the normalized user message.
START = "<start>"
END = "<end>"
11
+
12
+ # ── utilidades ------------------------------------------------
13
+ def _norm(s: str) -> str:
14
+ s = re.sub(r"[^a-zA-Z0-9?!.]+", " ", s.lower())
15
+ s = re.sub(r"([?.!])", r" \1 ", s)
16
+ return re.sub(r"\s+", " ", s).strip()
17
+
18
def _pad(seq):
    """Pad a batch of token-id sequences to ``MAXLEN`` with Keras ``pad_sequences``.

    Padding is added at the end of each sequence (``padding="post"``).
    Sequences longer than ``MAXLEN`` are cut using the library's default
    truncation side.
    """
    # NOTE(review): the Keras default is truncating="pre", which would drop
    # the first tokens of an over-long input — confirm this matches how the
    # training data was prepared.
    pad_sequences = tf.keras.preprocessing.sequence.pad_sequences
    return pad_sequences(seq, padding="post", maxlen=MAXLEN)
22
+
23
+ # ── carga modelo y tokenizer ----------------------------------
24
# Progress message printed without a newline; the final print below
# finishes the same console line once everything is loaded.
print("‣ cargando modelo y tokenizer…", end="", flush=True)
model = load_model(MODEL_PATH)
with open(TOK_PATH, encoding="utf-8") as f:
    tok = tokenizer_from_json(f.read())

# The layers are looked up by name in the loaded model, so the saved
# model must contain layers called exactly "emb", "enc_gru", "dec_gru"
# and "dense".
emb_layer = model.get_layer("emb")
enc_gru = model.get_layer("enc_gru")
dec_gru = model.get_layer("dec_gru")
dense = model.get_layer("dense")

# Encoder-only sub-model: maps the first model input to the second output
# of enc_gru.
# NOTE(review): output[1] is presumably the final hidden state of a GRU
# built with return_state=True — confirm against the training script.
enc_model = tf.keras.Model(model.input[0], enc_gru.output[1])
# The bare recurrent cell lets the decoder be advanced one token at a time.
dec_cell = dec_gru.cell

# Vocabulary ids of the special tokens; a KeyError here means the
# tokenizer file does not contain that token.
UNK_ID = tok.word_index["<unk>"]
START_ID = tok.word_index[START]
END_ID = tok.word_index[END]

print(" listo 🟢")
42
+
43
+ # ── paso único del decoder ------------------------------------
44
def _step(tok_id, state):
    """Advance the decoder by one token.

    Args:
        tok_id: Vocabulary id of the token fed to the decoder at this step.
        state: State list passed to the decoder cell as ``states``.

    Returns:
        A ``(scores, new_state)`` pair. ``scores`` is a NumPy array taken
        from the first row of the dense layer's output, with the ``<unk>``
        entry forced to ``-1e9``. ``new_state`` is ``[h]``, the cell output
        wrapped in a list for the next call.
    """
    token = tf.constant([[tok_id]], dtype=tf.int32)    # shape (1, 1)
    # Embed the token and drop axis 1 so the cell receives one vector per
    # batch row.
    embedded = tf.squeeze(emb_layer(token), axis=1)
    h, _ = dec_cell(embedded, states=state)
    scores = dense(h)[0].numpy()
    # Push <unk> far below every other entry so argmax never selects it.
    scores[UNK_ID] = -1e9
    return scores, [h]
53
+
54
+ # ── función de inferencia greedy -----------------------------
55
def reply(msg: str, max_len: int = MAXLEN) -> str:
    """Generate an answer to *msg* by greedy decoding.

    The message is normalized, wrapped in the start/end markers, tokenized
    and padded, then run through the encoder to obtain the initial decoder
    state. Decoding starts from the start token and picks the
    highest-scoring token at every step.

    Args:
        msg: Raw user message.
        max_len: Maximum number of decoding steps.

    Returns:
        The generated words joined by single spaces, or the literal
        ``"(sin respuesta)"`` when no token was produced.
    """
    prompt = f"{START} {_norm(msg)} {END}"
    encoder_input = _pad(tok.texts_to_sequences([prompt]))
    encoder_state = enc_model.predict(encoder_input, verbose=0)
    state = [tf.convert_to_tensor(encoder_state)]

    out_ids = []
    current = START_ID
    for _ in range(max_len):
        logits, state = _step(current, state)
        current = int(np.argmax(logits))

        # Stop when the model emits the end marker, or the start marker again.
        if current == END_ID or current == START_ID:
            break
        # Stop when the same token would appear three times in a row.
        if out_ids[-2:] == [current, current]:
            break

        out_ids.append(current)

    words = [tok.index_word[i] for i in out_ids]
    return " ".join(words) or "(sin respuesta)"
77
+
78
+ # ── demo CLI (opcional) ---------------------------------------
79
if __name__ == "__main__":
    # Interactive chat loop: read a line, skip blank input, print the reply.
    while True:
        try:
            q = input("Tú: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / closed stdin) or Ctrl-C: finish the
            # prompt line and leave the loop instead of dying with a traceback.
            print()
            break
        if not q:
            continue
        print("Bot:", reply(q))