twnlp commited on
Commit
b306957
·
verified ·
1 Parent(s): 43be024

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +31 -16
README.md CHANGED
@@ -12,25 +12,40 @@ pip install transformers
12
  ```
13
 
14
  ```python
15
- # pip install transformers
16
  from transformers import AutoModelForCausalLM, AutoTokenizer
17
- checkpoint = "twnlp/ChineseErrorCorrector2-7B"
18
 
19
- device = "cuda" # for GPU usage or "cpu" for CPU usage
20
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
21
- model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- input_content = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:\n少先队员因该为老人让坐。"
24
-
25
- messages = [{"role": "user", "content": input_content}]
26
- input_text=tokenizer.apply_chat_template(messages, tokenize=False)
27
-
28
- print(input_text)
29
-
30
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
31
- outputs = model.generate(inputs, max_new_tokens=1024, temperature=0, do_sample=False, repetition_penalty=1.08)
32
-
33
- print(tokenizer.decode(outputs[0]))
34
  ```
35
 
36
  output:
 
12
  ```
13
 
14
  ```python
 
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
16
 
17
+ model_name = "twnlp/ChineseErrorCorrector2-7B"
18
+
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ model_name,
21
+ torch_dtype="auto",
22
+ device_map="auto"
23
+ )
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+
26
+ prompt = "你是一个文本纠错专家,纠正输入句子中的语法错误,并输出正确的句子,输入句子为:"
27
+ text_input = "少先队员因该为老人让坐。"
28
+ messages = [
29
+ {"role": "user", "content": prompt + text_input}
30
+ ]
31
+ text = tokenizer.apply_chat_template(
32
+ messages,
33
+ tokenize=False,
34
+ add_generation_prompt=True
35
+ )
36
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
37
+
38
+ generated_ids = model.generate(
39
+ **model_inputs,
40
+ max_new_tokens=512
41
+ )
42
+ generated_ids = [
43
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
44
+ ]
45
+
46
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
47
+ print(response)
48
 
 
 
 
 
 
 
 
 
 
 
 
49
  ```
50
 
51
  output: