Robert Kenzo Medina Monsalve committed on
Commit
bb3ba7a
Β·
1 Parent(s): f2ef497

Deploying gru chatbot: only code w/o weights

Browse files
src/.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [server]
2
+ headless = true
3
+ enableCORS = false
src/00_prepare_cornell.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Descarga el corpus Cornell, extrae pares (Q,A) limpios
3
+ y guarda en `data/pairs.tsv` (tab-separated).
4
+ """
5
+
6
+ import os, zipfile, urllib.request, re, random, csv, json
7
+ from pathlib import Path
8
+
9
+ random.seed(42)
10
+ DATA_DIR = Path("data"); DATA_DIR.mkdir(exist_ok=True)
11
+ ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
12
+ ZIP_PATH = DATA_DIR/"cornell.zip"
13
+
14
+ if not (DATA_DIR/"cornell movie-dialogs corpus").exists():
15
+ print("β–Έ descargando corpus …")
16
+ urllib.request.urlretrieve(ZIP_URL, ZIP_PATH)
17
+ with zipfile.ZipFile(ZIP_PATH) as z: z.extractall(DATA_DIR)
18
+ ZIP_PATH.unlink()
19
+
20
BASE = DATA_DIR / "cornell movie-dialogs corpus"
lines_f = BASE / "movie_lines.txt"
conv_f = BASE / "movie_conversations.txt"

# --- map line-id -> utterance text ---------------------------------
# Fields in movie_lines.txt are separated by " +++$+++ "; the first
# field is the line id (e.g. "L1045") and the last is the raw text.
id2line = {}
with open(lines_f, encoding="latin-1") as fh:
    for raw in fh:
        line_id, *_, text = raw.strip().split(" +++$+++ ")
        id2line[line_id] = text
30
+
31
# --- conversations -> (Q, A) pairs ---------------------------------
import ast  # needed here to parse the id-list literals safely

pairs = []
with open(conv_f, encoding="latin-1") as f:
    for row in f:
        # The last field is a Python list literal such as "['L194', 'L195']".
        # ast.literal_eval parses it safely; the original eval() would
        # execute arbitrary code embedded in the corpus file.
        line_ids = ast.literal_eval(row.strip().split(" +++$+++ ")[-1])
        # Consecutive utterances of a conversation become (question, answer).
        for i in range(len(line_ids) - 1):
            q, a = id2line[line_ids[i]], id2line[line_ids[i + 1]]
            pairs.append((q, a))
39
+
40
# light text clean-up
def norm(t: str) -> str:
    """Lower-case *t*, keep only [a-z0-9.!?], collapse runs of whitespace."""
    cleaned = re.sub(r"[^a-zA-Z0-9.!?]+", " ", t.lower())
    return re.sub(r"\s+", " ", cleaned).strip()
44
+
45
# Keep only exchanges where both sides have 2-20 raw tokens,
# then normalise the survivors.
pairs = [
    (norm(q), norm(a))
    for q, a in pairs
    if 2 <= len(q.split()) <= 20 and 2 <= len(a.split()) <= 20
]

random.shuffle(pairs)

# Write one "question<TAB>answer" row per pair.
out_path = DATA_DIR / "pairs.tsv"
with open(out_path, "w", newline='', encoding="utf-8") as out:
    csv.writer(out, delimiter="\t").writerows(pairs)

print(f"Pairs listos → {len(pairs):,} líneas.")
src/formatted_movie_lines_exporter.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# build_pairs.py - generates formatted_movie_lines.txt -------------
import os, re, zipfile, urllib.request, csv, unicodedata

URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
ZIP = "cornell.zip"
ROOT = "cornell movie-dialogs corpus"
OUT = "formatted_movie_lines.txt"   # consumed by the later pipeline steps
MAX_SENT = 20                       # drop overly long sentences
9
+
10
+ def ascii(txt):
11
+ return "".join(c for c in unicodedata.normalize("NFD", txt)
12
+ if unicodedata.category(c) != "Mn")
13
+
14
+ def norm(s):
15
+ s = ascii(re.sub(r"[^a-zA-Z0-9?!.]+", " ", s.lower()))
16
+ s = re.sub(r"([?.!])", r" \1 ", s)
17
+ return re.sub(r"\s+", " ", s).strip()
18
+
19
# --- download and unzip the corpus (first run only) ---------------
if not os.path.isdir(ROOT):
    print("⏬ descargando corpus…")
    urllib.request.urlretrieve(URL, ZIP)
    with zipfile.ZipFile(ZIP) as zf:
        zf.extractall()
    os.remove(ZIP)
25
+
26
# --- read lines and conversations ---------------------------------
import ast  # needed here to parse the id-list literals safely

print("🔧 procesando…")
lines = {}
with open(os.path.join(ROOT, "movie_lines.txt"), encoding="latin-1") as f:
    for ln in f:
        # Fields separated by " +++$+++ ": first = line id, last = text.
        parts = ln.strip().split(" +++$+++ ")
        lines[parts[0]] = norm(parts[-1])

pairs = []
with open(os.path.join(ROOT, "movie_conversations.txt"),
          encoding="latin-1") as f:
    for conv in f:
        # Last field is a list literal like "['L194', 'L195']".
        # ast.literal_eval parses it safely; eval() would execute
        # arbitrary code embedded in the corpus file.
        ids = ast.literal_eval(conv.strip().split(" +++$+++ ")[-1])
        for a, b in zip(ids, ids[1:]):
            q, r = lines[a], lines[b]
            # NOTE(review): strict '<' keeps at most MAX_SENT-1 tokens,
            # while 00_prepare_cornell.py uses '<=' — confirm which bound
            # is intended; kept as-is to preserve behavior.
            if (2 <= len(q.split()) < MAX_SENT and
                    2 <= len(r.split()) < MAX_SENT):
                pairs.append((q, r))
44
+
45
# --- save as TSV (question[TAB]answer) ----------------------------
with open(OUT, "w", encoding="utf-8", newline="") as sink:
    writer = csv.writer(sink, delimiter='\t')
    writer.writerows(pairs)

print(f"✅ creado {OUT} con {len(pairs):,} pares")
src/requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.2.2
2
+ altair==5.5.0
3
+ astunparse==1.6.3
4
+ attrs==25.3.0
5
+ blinker==1.9.0
6
+ cachetools==5.5.2
7
+ certifi==2025.4.26
8
+ charset-normalizer==3.4.2
9
+ click==8.2.1
10
+ colorama==0.4.6
11
+ contourpy==1.3.2
12
+ cycler==0.12.1
13
+ flatbuffers==25.2.10
14
+ fonttools==4.58.0
15
+ gast==0.6.0
16
+ gitdb==4.0.12
17
+ GitPython==3.1.44
18
+ google-auth==2.40.2
19
+ google-auth-oauthlib==1.2.2
20
+ google-pasta==0.2.0
21
+ grpcio==1.71.0
22
+ h5py==3.13.0
23
+ idna==3.10
24
+ Jinja2==3.1.6
25
+ jsonschema==4.23.0
26
+ jsonschema-specifications==2025.4.1
27
+ kiwisolver==1.4.8
28
+ libclang==18.1.1
29
+ Markdown==3.8
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==3.0.2
32
+ matplotlib==3.7.5
33
+ mdurl==0.1.2
34
+ ml-dtypes==0.2.0
35
+ narwhals==1.40.0
36
+ numpy==1.23.5
37
+ oauthlib==3.2.2
38
+ opt_einsum==3.4.0
39
+ packaging==24.2
40
+ pandas==2.2.3
41
+ pillow==10.4.0
42
+ protobuf==4.25.7
43
+ pyarrow==20.0.0
44
+ pyasn1==0.6.1
45
+ pyasn1-modules==0.4.2
46
+ pydeck==0.9.1
47
+ Pygments==2.19.1
48
+ pyparsing==3.2.3
49
+ python-dateutil==2.9.0.post0
50
+ pytz==2025.2
51
+ referencing==0.36.2
52
+ requests==2.32.3
53
+ requests-oauthlib==2.0.0
54
+ rich==13.9.4
55
+ rpds-py==0.25.1
56
+ rsa==4.9.1
57
+ setuptools==65.5.1
58
+ six==1.17.0
59
+ smmap==5.0.2
60
+ streamlit==1.33.0
61
+ tenacity==8.5.0
62
+ tensorboard==2.15.2
63
+ tensorboard-data-server==0.7.2
64
+ tensorflow==2.15.1
+ tensorflow-io-gcs-filesystem==0.31.0
65
+ termcolor==3.1.0
66
+ toml==0.10.2
67
+ tornado==6.5.1
68
+ typing_extensions==4.13.2
69
+ tzdata==2025.2
70
+ urllib3==2.4.0
71
+ watchdog==6.0.0
72
+ Werkzeug==3.1.3
73
+ wheel==0.38.4
74
+ wrapt==1.14.1
src/serve_gru.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# serve_gru.py - greedy-decoding inference for the GRU seq2seq bot --
import re, numpy as np, tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json

MODEL_PATH, TOK_PATH = "chatbot_seq2seq.keras", "tokenizer.json"
MAXLEN = 22  # fixed sequence length — presumably matches training-time padding; confirm
START, END = "<start>", "<end>"
9
+
10
+ # ── utilidades ------------------------------------------------
11
+ def _norm(s: str) -> str:
12
+ s = re.sub(r"[^a-zA-Z0-9?!.]+", " ", s.lower())
13
+ s = re.sub(r"([?.!])", r" \1 ", s)
14
+ return re.sub(r"\s+", " ", s).strip()
15
+
16
def _pad(seq):
    """Right-pad (post) token-id sequences to the model's fixed MAXLEN."""
    padder = tf.keras.preprocessing.sequence.pad_sequences
    return padder(seq, maxlen=MAXLEN, padding="post")
20
+
21
# -- load model and tokenizer ---------------------------------------
print("‣ cargando modelo y tokenizer…", end="", flush=True)

model = load_model(MODEL_PATH)
with open(TOK_PATH, encoding="utf-8") as fh:
    tok = tokenizer_from_json(fh.read())

# Grab the trained layers by name so they can be rewired for inference.
emb_layer = model.get_layer("emb")
enc_gru = model.get_layer("enc_gru")
dec_gru = model.get_layer("dec_gru")
dense = model.get_layer("dense")

# Encoder sub-model: input sentence -> final hidden state.
# NOTE(review): assumes enc_gru was built with return_state=True so
# that output[1] is the state tensor — confirm against training code.
enc_model = tf.keras.Model(model.input[0], enc_gru.output[1])
# The decoder is driven one token at a time through its cell.
dec_cell = dec_gru.cell

# Special-token ids from the tokenizer's vocabulary.
UNK_ID = tok.word_index["<unk>"]
START_ID = tok.word_index[START]
END_ID = tok.word_index[END]

print(" listo 🟢")
40
+
41
# -- single decoder step --------------------------------------------
def _step(tok_id, state):
    """Run one decoder step: (token id, state) -> (vocab logits, new state)."""
    ids = tf.constant([[tok_id]], dtype=tf.int32)   # shape (1, 1)
    emb = emb_layer(ids)                            # (1, 1, emb_dim)
    emb = tf.squeeze(emb, axis=1)                   # (1, emb_dim)
    h, _ = dec_cell(emb, states=state)              # (1, units)
    logits = dense(h)[0].numpy()                    # (vocab,)
    logits[UNK_ID] = -1e9                           # never emit <unk>
    return logits, [h]
51
+
52
+ # ── funciΓ³n de inferencia greedy -----------------------------
53
+ def reply(msg: str, max_len: int = MAXLEN) -> str:
54
+ # normaliza y codifica
55
+ seq = _pad(tok.texts_to_sequences([f"{START} {_norm(msg)} {END}"]))
56
+ h_enc = enc_model.predict(seq, verbose=0) # (1,units)
57
+ state = [tf.convert_to_tensor(h_enc)] # [(1,units)]
58
+
59
+ tok_id, out_ids = START_ID, []
60
+ for _ in range(max_len):
61
+ logits, state = _step(tok_id, state)
62
+ # greedy: la mΓ‘s probable
63
+ tok_id = int(np.argmax(logits))
64
+
65
+ # condiciones de parada
66
+ if tok_id in (END_ID, START_ID):
67
+ break
68
+ if len(out_ids) >= 2 and tok_id == out_ids[-1] == out_ids[-2]:
69
+ break
70
+
71
+ out_ids.append(tok_id)
72
+
73
+ # reconstruye texto
74
+ return " ".join(tok.index_word[i] for i in out_ids) or "(sin respuesta)"
75
+
76
# -- CLI demo (optional) --------------------------------------------
if __name__ == "__main__":
    while True:
        try:
            q = input("Tú: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly when stdin closes or on Ctrl-C, instead of
            # dying with an unhandled-exception traceback.
            break
        if not q:
            continue
        print("Bot:", reply(q))
src/streamlit_app.py CHANGED
@@ -1,40 +1,22 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
import streamlit as st
from serve_gru import reply

st.set_page_config(page_title="Chatbot GRU", page_icon="🤖")
st.title("💬 Chatbot GRU (Cornell Movie Dialogs)")

# Conversation history lives in session state so it survives reruns.
if "history" not in st.session_state:
    st.session_state.history = []

# Built-in chat widget: returns the submitted text once, else None.
user_msg = st.chat_input("Escribe tu mensaje...")
if user_msg:
    # Record the user's turn, then the model's answer.
    st.session_state.history.append(("user", user_msg))
    st.session_state.history.append(("assistant", reply(user_msg)))

# Re-render the whole conversation on every run.
for role, text in st.session_state.history:
    st.chat_message(role).markdown(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff