manthilaffs commited on
Commit
aaa6d06
·
verified ·
1 Parent(s): 1925104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -18
app.py CHANGED
@@ -15,7 +15,23 @@ alpaca_prompt = """පහත දැක්වෙන්නේ යම් කාර
15
  ### ප්‍රතිචාරය:
16
  {}"""
17
 
18
- def infer_stream(message, history, enable_history=False, max_new_tokens=512):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  global model, tokenizer
20
 
21
  if model is None:
@@ -45,18 +61,8 @@ def infer_stream(message, history, enable_history=False, max_new_tokens=512):
45
 
46
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
47
 
48
- # Setup streaming
49
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
50
-
51
- generation_kwargs = dict(
52
- **inputs,
53
- max_new_tokens=max_new_tokens,
54
- streamer=streamer,
55
- )
56
-
57
- # Start generation in a separate thread
58
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
59
- thread.start()
60
 
61
  # Stream the output
62
  partial_text = ""
@@ -75,11 +81,6 @@ def infer_stream(message, history, enable_history=False, max_new_tokens=512):
75
 
76
  thread.join()
77
 
78
- @spaces.GPU
79
- def infer(message, history, enable_history=False, max_new_tokens=512):
80
- # Return the generator for streaming
81
- return infer_stream(message, history, enable_history, max_new_tokens)
82
-
83
  # Custom CSS for styling
84
  custom_css = """
85
  #splash-screen {
 
15
  ### ප්‍රතිචාරය:
16
  {}"""
17
 
18
+ @spaces.GPU
19
+ def generate_with_streaming(inputs, max_new_tokens):
20
+ global model
21
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
22
+
23
+ generation_kwargs = dict(
24
+ **inputs,
25
+ max_new_tokens=max_new_tokens,
26
+ streamer=streamer,
27
+ )
28
+
29
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
30
+ thread.start()
31
+
32
+ return streamer, thread
33
+
34
+ def infer(message, history, enable_history=False, max_new_tokens=512):
35
  global model, tokenizer
36
 
37
  if model is None:
 
61
 
62
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
63
 
64
+ # Get streamer and thread from GPU function
65
+ streamer, thread = generate_with_streaming(inputs, max_new_tokens)
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Stream the output
68
  partial_text = ""
 
81
 
82
  thread.join()
83
 
 
 
 
 
 
84
  # Custom CSS for styling
85
  custom_css = """
86
  #splash-screen {