rinrikatoki commited on
Commit
7256add
·
verified ·
1 Parent(s): 6ad1c48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -39
app.py CHANGED
@@ -1,39 +1,39 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
- import gradio as gr
4
-
5
- model_id = "rinrikatoki/dorna-merged-full"
6
-
7
- tokenizer = AutoTokenizer.from_pretrained(model_id)
8
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
9
- model = model.eval()
10
-
11
- def chat(message, history):
12
- if history is None:
13
- history = []
14
-
15
- prompt = ""
16
- for user, bot in history:
17
- prompt += f"<|user|>\n{user}\n<|assistant|>\n{bot}\n"
18
- prompt += f"<|user|>\n{message}\n<|assistant|>\n"
19
-
20
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
21
- input_ids = input_ids.to(model.device)
22
-
23
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
24
-
25
- output = model.generate(
26
- input_ids,
27
- max_new_tokens=512,
28
- temperature=0.7,
29
- top_p=0.95,
30
- do_sample=True,
31
- streamer=streamer,
32
- )
33
-
34
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
35
- bot_reply = output_text.split("<|assistant|>")[-1].strip()
36
- history.append((message, bot_reply))
37
- return "", history
38
-
39
- gr.ChatInterface(chat).launch()
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
+ import gradio as gr
4
+
5
+ model_id = "rinrikatoki/dorna-merged-4bit"
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
8
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
9
+ model = model.eval()
10
+
11
+ def chat(message, history):
12
+ if history is None:
13
+ history = []
14
+
15
+ prompt = ""
16
+ for user, bot in history:
17
+ prompt += f"<|user|>\n{user}\n<|assistant|>\n{bot}\n"
18
+ prompt += f"<|user|>\n{message}\n<|assistant|>\n"
19
+
20
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
21
+ input_ids = input_ids.to(model.device)
22
+
23
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
24
+
25
+ output = model.generate(
26
+ input_ids,
27
+ max_new_tokens=512,
28
+ temperature=0.7,
29
+ top_p=0.95,
30
+ do_sample=True,
31
+ streamer=streamer,
32
+ )
33
+
34
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
35
+ bot_reply = output_text.split("<|assistant|>")[-1].strip()
36
+ history.append((message, bot_reply))
37
+ return "", history
38
+
39
+ gr.ChatInterface(chat).launch()