ysharma HF Staff commited on
Commit
6fd0047
·
1 Parent(s): 350aa40
Files changed (1) hide show
  1. app.py +17 -44
app.py CHANGED
@@ -1,47 +1,31 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
4
- import time
5
- import numpy as np
6
- from torch.nn import functional as F
7
- import os
8
  from threading import Thread
9
 
10
- # init
11
- tok = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
12
- m = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16) #load_in_8bit=True)
13
- m = m.to('cuda:0')
14
 
15
  class StopOnTokens(StoppingCriteria):
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
- #stop_ids = [[29, 13961, 31], [29, 12042, 31], 1, 0]
18
  stop_ids = [29, 0]
19
  for stop_id in stop_ids:
20
- #print(f"^^input ids - {input_ids}")
21
  if input_ids[0][-1] == stop_id:
22
  return True
23
  return False
24
 
25
- def chat(message, history):
26
 
27
- print(f"chatbot : {history}")
28
- print(f"message : {message}")
29
- history = history + [[message, ""]]
30
- print(f"chatbot : {history}")
31
-
32
- # Initialize a StopOnTokens object
33
  stop = StopOnTokens()
34
 
35
- # Construct the input message string for the model by concatenating the current system message and conversation history
36
  messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
37
- for item in history])
38
- print(f"messages : {messages}")
39
 
40
- # Tokenize the messages string
41
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
42
-
43
  streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
44
-
45
  generate_kwargs = dict(
46
  model_inputs,
47
  streamer=streamer,
@@ -53,26 +37,15 @@ def chat(message, history):
53
  num_beams=1,
54
  stopping_criteria=StoppingCriteriaList([stop])
55
  )
56
- t = Thread(target=m.generate, kwargs=generate_kwargs)
57
  t.start()
58
 
59
- # Initialize an empty string to store the generated text
60
- partial_text = ""
61
- for new_text in streamer:
62
- print(new_text)
63
- if new_text != '<':
64
- partial_text += new_text
65
- #history[-1][1] = partial_text.split('<bot>:')[-1]
66
- # Yield an empty string to clean up the message textbox and the updated conversation history
67
- yield partial_text
68
- #return partial_text
69
 
70
 
71
- title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
72
- description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""
73
- theme = gr.themes.Soft(
74
- primary_hue=gr.themes.Color("#ededed", "#fee2e2", "#fecaca", "#fca5a5", "#f87171", "#ef4444", "#dc2626", "#b91c1c", "#991b1b", "#7f1d1d", "#6c1e1e"),
75
- neutral_hue="red",
76
- )
77
-
78
- gr.ChatInterface(chat, delete_last_btn="❌Delete").queue().launch(debug=True)
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 
 
 
 
 
3
  from threading import Thread
4
 
5
+ tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
6
+ model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
7
+ model = model.to('cuda:0')
 
8
 
9
  class StopOnTokens(StoppingCriteria):
10
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
11
  stop_ids = [29, 0]
12
  for stop_id in stop_ids:
 
13
  if input_ids[0][-1] == stop_id:
14
  return True
15
  return False
16
 
17
+ def predict(message, history):
18
 
19
+ history_transformer_format = history + [[message, ""]]
 
 
 
 
 
20
  stop = StopOnTokens()
21
 
22
+ #Construct the input message string for the model by concatenating the current system message and conversation history
23
  messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
24
+ for item in history_transformer_format])
 
25
 
26
+ #Tokenize the messages string
27
  model_inputs = tok([messages], return_tensors="pt").to("cuda")
 
28
  streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
 
29
  generate_kwargs = dict(
30
  model_inputs,
31
  streamer=streamer,
 
37
  num_beams=1,
38
  stopping_criteria=StoppingCriteriaList([stop])
39
  )
40
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
41
  t.start()
42
 
43
+ #Initialize an empty string to store the generated text
44
+ partial_message = ""
45
+ for new_token in streamer:
46
+ if new_token != '<':
47
+ partial_message += new_token
48
+ yield partial_message
 
 
 
 
49
 
50
 
51
+ gr.ChatInterface(predict).queue().launch()