Spaces:
Sleeping
Sleeping
Fix incorrect input length error using hash as data
Browse files- chatbot_constructor.py +4 -1
chatbot_constructor.py
CHANGED
|
@@ -29,7 +29,7 @@ def todset(text: str):
|
|
| 29 |
def hash_str(data: str):
|
| 30 |
return hashlib.md5(data.encode('utf-8')).hexdigest()
|
| 31 |
|
| 32 |
-
def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb_size: int = 128,
|
| 33 |
data_hash = None
|
| 34 |
if "→" not in data or "\n" not in data:
|
| 35 |
if data in os.listdir("cache"):
|
|
@@ -42,10 +42,13 @@ def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb
|
|
| 42 |
tokenizer.fit_on_texts(list(dset.keys()))
|
| 43 |
|
| 44 |
vocab_size = len(tokenizer.word_index) + 1
|
|
|
|
| 45 |
if data_hash is None:
|
| 46 |
data_hash = hash_str(data)+"_"+str(epochs)+"_"+str(learning_rate)+"_"+str(emb_size)+"_"+str(inp_len)+"_"+str(kernels_count)+"_"+str(kernel_size)+".keras"
|
| 47 |
elif message == "!getmodelhash":
|
| 48 |
return data_hash
|
|
|
|
|
|
|
| 49 |
if data_hash in os.listdir("cache"):
|
| 50 |
model = load_model("cache/"+data_hash)
|
| 51 |
else:
|
|
|
|
| 29 |
def hash_str(data: str):
|
| 30 |
return hashlib.md5(data.encode('utf-8')).hexdigest()
|
| 31 |
|
| 32 |
+
def train(message: str = "", epochs: int = 16, learning_rate: float = 0.001, emb_size: int = 128, input_len: int = 16, kernels_count: int = 8, kernel_size: int = 8, data: str = ""):
|
| 33 |
data_hash = None
|
| 34 |
if "→" not in data or "\n" not in data:
|
| 35 |
if data in os.listdir("cache"):
|
|
|
|
| 42 |
tokenizer.fit_on_texts(list(dset.keys()))
|
| 43 |
|
| 44 |
vocab_size = len(tokenizer.word_index) + 1
|
| 45 |
+
inp_len = input_len
|
| 46 |
if data_hash is None:
|
| 47 |
data_hash = hash_str(data)+"_"+str(epochs)+"_"+str(learning_rate)+"_"+str(emb_size)+"_"+str(inp_len)+"_"+str(kernels_count)+"_"+str(kernel_size)+".keras"
|
| 48 |
elif message == "!getmodelhash":
|
| 49 |
return data_hash
|
| 50 |
+
else:
|
| 51 |
+
inp_len = int(data_hash.split("_")[-3])
|
| 52 |
if data_hash in os.listdir("cache"):
|
| 53 |
model = load_model("cache/"+data_hash)
|
| 54 |
else:
|