ysharma HF Staff commited on
Commit
5638bd8
·
1 Parent(s): 72ad694

added chat interface

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -9,8 +9,8 @@ from threading import Thread
9
 
10
  # init
11
  tok = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
12
- m = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
13
- m = m.to('cuda:0')
14
 
15
  class StopOnTokens(StoppingCriteria):
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -28,30 +28,35 @@ def user(message, history):
28
  return "", history + [[message, ""]]
29
 
30
 
 
 
 
 
 
31
 
32
- def chat(history, top_p, top_k, temperature):
33
  # Initialize a StopOnTokens object
34
  stop = StopOnTokens()
35
 
36
  # Construct the input message string for the model by concatenating the current system message and conversation history
37
  messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
38
  for item in history])
39
-
40
  # Tokenize the messages string
41
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
42
- streamer = TextIteratorStreamer(
43
- tok, timeout=10., skip_prompt=False, skip_special_tokens=True)
 
44
  generate_kwargs = dict(
45
  model_inputs,
46
  streamer=streamer,
47
  max_new_tokens=1024,
48
  do_sample=True,
49
- top_p=top_p, #0.95,
50
- top_k=top_k, #1000,
51
- temperature=temperature, #1.0,
52
  num_beams=1,
53
  stopping_criteria=StoppingCriteriaList([stop])
54
- )
55
  t = Thread(target=m.generate, kwargs=generate_kwargs)
56
  t.start()
57
 
@@ -61,11 +66,11 @@ def chat(history, top_p, top_k, temperature):
61
  #print(new_text)
62
  if new_text != '<':
63
  partial_text += new_text
64
- history[-1][1] = partial_text.split('<bot>:')[-1]
65
  # Yield an empty string to clean up the message textbox and the updated conversation history
66
- yield history
67
- return partial_text
68
-
69
 
70
  title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
71
  description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""
@@ -74,6 +79,7 @@ theme = gr.themes.Soft(
74
  neutral_hue="red",
75
  )
76
 
 
77
 
78
  with gr.Blocks(theme=theme) as demo:
79
  gr.HTML(title)
@@ -113,5 +119,5 @@ with gr.Blocks(theme=theme) as demo:
113
  )
114
  gr.HTML(description)
115
 
116
- demo.queue(max_size=32, concurrency_count=2)
117
- demo.launch(debug=True)
 
9
 
10
  # init
11
  tok = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
12
+ m = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", load_in_8bit=True) #torch_dtype=torch.float16)
13
+ #m = m.to('cuda:0')
14
 
15
  class StopOnTokens(StoppingCriteria):
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
28
  return "", history + [[message, ""]]
29
 
30
 
31
+ def chat(message, history):
32
+
33
+ print(f"chatbot : {history}")
34
+ #history = history + [[message, ""]]
35
+ #print(f"chatbot : {history}")
36
 
 
37
  # Initialize a StopOnTokens object
38
  stop = StopOnTokens()
39
 
40
  # Construct the input message string for the model by concatenating the current system message and conversation history
41
  messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
42
  for item in history])
43
+
44
  # Tokenize the messages string
45
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
46
+
47
+ streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=False, skip_special_tokens=True)
48
+
49
  generate_kwargs = dict(
50
  model_inputs,
51
  streamer=streamer,
52
  max_new_tokens=1024,
53
  do_sample=True,
54
+ top_p=0.95,
55
+ top_k=1000,
56
+ temperature=1.0,
57
  num_beams=1,
58
  stopping_criteria=StoppingCriteriaList([stop])
59
+ )
60
  t = Thread(target=m.generate, kwargs=generate_kwargs)
61
  t.start()
62
 
 
66
  #print(new_text)
67
  if new_text != '<':
68
  partial_text += new_text
69
+ #history[-1][1] = partial_text.split('<bot>:')[-1]
70
  # Yield an empty string to clean up the message textbox and the updated conversation history
71
+ yield partial_text
72
+ #return partial_text
73
+
74
 
75
  title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
76
  description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""
 
79
  neutral_hue="red",
80
  )
81
 
82
+ gr.ChatInterface(chat, delete_last_btn="❌Delete").queue().launch(debug=True)
83
 
84
  with gr.Blocks(theme=theme) as demo:
85
  gr.HTML(title)
 
119
  )
120
  gr.HTML(description)
121
 
122
+ #demo.queue(max_size=32, concurrency_count=2)
123
+ #demo.launch(debug=True)