manthilaffs committed on
Commit
04d4513
·
verified ·
1 Parent(s): aaa6d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -33
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import spaces
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
- from threading import Thread
6
 
7
  model = None
8
  tokenizer = None
@@ -16,21 +15,6 @@ alpaca_prompt = """පහත දැක්වෙන්නේ යම් කාර
16
  {}"""
17
 
18
  @spaces.GPU
19
- def generate_with_streaming(inputs, max_new_tokens):
20
- global model
21
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
22
-
23
- generation_kwargs = dict(
24
- **inputs,
25
- max_new_tokens=max_new_tokens,
26
- streamer=streamer,
27
- )
28
-
29
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
30
- thread.start()
31
-
32
- return streamer, thread
33
-
34
  def infer(message, history, enable_history=False, max_new_tokens=512):
35
  global model, tokenizer
36
 
@@ -61,25 +45,15 @@ def infer(message, history, enable_history=False, max_new_tokens=512):
61
 
62
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
63
 
64
- # Get streamer and thread from GPU function
65
- streamer, thread = generate_with_streaming(inputs, max_new_tokens)
66
 
67
- # Stream the output
68
- partial_text = ""
69
- response_started = False
70
 
71
- for new_text in streamer:
72
- partial_text += new_text
73
-
74
- # Check if we've reached the response section
75
- if not response_started and "### ප්‍රතිචාරය:" in partial_text:
76
- partial_text = partial_text.split("### ප්‍රතිචාරය:")[-1].strip()
77
- response_started = True
78
-
79
- if response_started:
80
- yield partial_text
81
 
82
- thread.join()
83
 
84
  # Custom CSS for styling
85
  custom_css = """
 
1
  import gradio as gr
2
  import torch
3
  import spaces
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
  model = None
7
  tokenizer = None
 
15
  {}"""
16
 
17
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def infer(message, history, enable_history=False, max_new_tokens=512):
19
  global model, tokenizer
20
 
 
45
 
46
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
47
 
48
+ with torch.inference_mode():
49
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
50
 
51
+ text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
52
 
53
+ if "### ප්‍රතිචාරය:" in text:
54
+ text = text.split("### ප්‍රතිචාරය:")[-1].strip()
 
 
 
 
 
 
 
 
55
 
56
+ return text
57
 
58
  # Custom CSS for styling
59
  custom_css = """