Spaces:
Runtime error
Runtime error
Commit
·
8f3dc34
1
Parent(s):
8cb62ff
Update with h2oGPT hash 3e927fb6330dd3d1256b47eb201bd376230dd20a
Browse files- generate.py +3 -7
- utils.py +0 -50
generate.py
CHANGED
|
@@ -3,8 +3,9 @@ import sys
|
|
| 3 |
import os
|
| 4 |
import traceback
|
| 5 |
import typing
|
|
|
|
| 6 |
|
| 7 |
-
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext,
|
| 8 |
|
| 9 |
SEED = 1236
|
| 10 |
set_seed(SEED)
|
|
@@ -828,15 +829,10 @@ def evaluate(
|
|
| 828 |
skip_prompt = False
|
| 829 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
| 830 |
gen_kwargs.update(dict(streamer=streamer))
|
| 831 |
-
if debug:
|
| 832 |
-
KThread.show_threads()
|
| 833 |
target_func = generate_with_exceptions
|
| 834 |
-
if concurrency_count == 1:
|
| 835 |
-
# otherwise can't do this
|
| 836 |
-
KThread.kill_threads(target_func.__name__, debug=debug)
|
| 837 |
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
| 838 |
raise_generate_gpu_exceptions, **gen_kwargs)
|
| 839 |
-
thread =
|
| 840 |
thread.start()
|
| 841 |
outputs = ""
|
| 842 |
for new_text in streamer:
|
|
|
|
| 3 |
import os
|
| 4 |
import traceback
|
| 5 |
import typing
|
| 6 |
+
from threading import Thread
|
| 7 |
|
| 8 |
+
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
|
| 9 |
|
| 10 |
SEED = 1236
|
| 11 |
set_seed(SEED)
|
|
|
|
| 829 |
skip_prompt = False
|
| 830 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
| 831 |
gen_kwargs.update(dict(streamer=streamer))
|
|
|
|
|
|
|
| 832 |
target_func = generate_with_exceptions
|
|
|
|
|
|
|
|
|
|
| 833 |
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
| 834 |
raise_generate_gpu_exceptions, **gen_kwargs)
|
| 835 |
+
thread = Thread(target=target)
|
| 836 |
thread.start()
|
| 837 |
outputs = ""
|
| 838 |
for new_text in streamer:
|
utils.py
CHANGED
|
@@ -244,56 +244,6 @@ class NullContext(threading.local):
|
|
| 244 |
pass
|
| 245 |
|
| 246 |
|
| 247 |
-
class KThread(threading.Thread):
|
| 248 |
-
"""Thread with a kill method."""
|
| 249 |
-
|
| 250 |
-
def __init__(self, *args, **keywords):
|
| 251 |
-
threading.Thread.__init__(self, *args, **keywords)
|
| 252 |
-
self.killed = False
|
| 253 |
-
|
| 254 |
-
def start(self):
|
| 255 |
-
"""Start the thread."""
|
| 256 |
-
self.__run_backup = self.run
|
| 257 |
-
self.run = self.__run # Force the Thread to install our trace.
|
| 258 |
-
threading.Thread.start(self)
|
| 259 |
-
|
| 260 |
-
def __run(self):
|
| 261 |
-
"""install trace."""
|
| 262 |
-
sys.settrace(self.globaltrace)
|
| 263 |
-
self.__run_backup()
|
| 264 |
-
self.run = self.__run_backup
|
| 265 |
-
|
| 266 |
-
def globaltrace(self, frame, why, arg):
|
| 267 |
-
if why == 'call':
|
| 268 |
-
return self.localtrace
|
| 269 |
-
else:
|
| 270 |
-
return None
|
| 271 |
-
|
| 272 |
-
def localtrace(self, frame, why, arg):
|
| 273 |
-
if self.killed:
|
| 274 |
-
if why == 'line':
|
| 275 |
-
raise SystemExit()
|
| 276 |
-
return self.localtrace
|
| 277 |
-
|
| 278 |
-
def kill(self):
|
| 279 |
-
self.killed = True
|
| 280 |
-
|
| 281 |
-
@staticmethod
|
| 282 |
-
def show_threads():
|
| 283 |
-
for thread in threading.enumerate():
|
| 284 |
-
print(thread.name, flush=True)
|
| 285 |
-
|
| 286 |
-
@staticmethod
|
| 287 |
-
def kill_threads(name, debug=False):
|
| 288 |
-
for thread in threading.enumerate():
|
| 289 |
-
if name in thread.name:
|
| 290 |
-
if debug:
|
| 291 |
-
print("Trying to kill %s %s" % (thread.ident, thread), flush=True)
|
| 292 |
-
thread.kill()
|
| 293 |
-
if debug:
|
| 294 |
-
print(thread, flush=True)
|
| 295 |
-
|
| 296 |
-
|
| 297 |
def wrapped_partial(func, *args, **kwargs):
|
| 298 |
"""
|
| 299 |
Give partial properties of normal function, like __name__ attribute etc.
|
|
|
|
| 244 |
pass
|
| 245 |
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
def wrapped_partial(func, *args, **kwargs):
|
| 248 |
"""
|
| 249 |
Give partial properties of normal function, like __name__ attribute etc.
|