Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -56,7 +56,7 @@ def init_model_from(url, filename):
|
|
| 56 |
ckpt_path = Path(out_dir) / filename
|
| 57 |
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
| 58 |
if not os.path.exists(ckpt_path):
|
| 59 |
-
gr.Info('Downloading model...')
|
| 60 |
download_file(url, ckpt_path)
|
| 61 |
gr.Info('✅Model downloaded successfully.', duration=2)
|
| 62 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
|
@@ -70,24 +70,24 @@ def init_model_from(url, filename):
|
|
| 70 |
model.load_state_dict(state_dict)
|
| 71 |
return model
|
| 72 |
|
| 73 |
-
def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k):
|
|
|
|
| 74 |
x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
|
| 75 |
with torch.no_grad():
|
| 76 |
for k in range(samples):
|
|
|
|
| 77 |
generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
| 78 |
|
| 79 |
output = decode(generated[0].tolist())
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
match_botoutput = re.search(r'<human>(.*?)<', output)
|
| 82 |
-
match_emotion = re.search(r'<emotion>\s*(.*?)\s*<', output)
|
| 83 |
-
match_context = re.search(r'<context>\s*(.*?)\s*<', output)
|
| 84 |
response = ''
|
| 85 |
-
emotion = ''
|
| 86 |
-
context = ''
|
| 87 |
if match_botoutput:
|
| 88 |
try :
|
| 89 |
-
response = match_botoutput.group(1).
|
| 90 |
except:
|
| 91 |
-
response =
|
| 92 |
#return response, emotion, context
|
| 93 |
-
return [input, response]
|
|
|
|
| 56 |
ckpt_path = Path(out_dir) / filename
|
| 57 |
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
| 58 |
if not os.path.exists(ckpt_path):
|
| 59 |
+
gr.Info('Downloading model...',duration=10)
|
| 60 |
download_file(url, ckpt_path)
|
| 61 |
gr.Info('✅Model downloaded successfully.', duration=2)
|
| 62 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
|
|
|
| 70 |
model.load_state_dict(state_dict)
|
| 71 |
return model
|
| 72 |
|
| 73 |
+
def respond(input, samples, model, encode, decode, max_new_tokens,temperature, top_k):
|
| 74 |
+
input = "<bot> " + input
|
| 75 |
x = (torch.tensor(encode(input), dtype=torch.long, device=device)[None, ...])
|
| 76 |
with torch.no_grad():
|
| 77 |
for k in range(samples):
|
| 78 |
+
|
| 79 |
generated = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
|
| 80 |
|
| 81 |
output = decode(generated[0].tolist())
|
| 82 |
+
# if input in output:
|
| 83 |
+
# output = output.split(input)[-1].strip() # Take the part after `<input>`
|
| 84 |
|
| 85 |
+
match_botoutput = re.search(r'<human>(.*?)<', output, re.DOTALL)
|
|
|
|
|
|
|
| 86 |
response = ''
|
|
|
|
|
|
|
| 87 |
if match_botoutput:
|
| 88 |
try :
|
| 89 |
+
response = match_botoutput.group(1).strip()
|
| 90 |
except:
|
| 91 |
+
response = ''
|
| 92 |
#return response, emotion, context
|
| 93 |
+
return [input, response, output]
|