twin2ryo committed · Commit 43aa96d (verified) · Parent: 651a46b

Update README.md

Files changed (1): README.md +32 -8
README.md CHANGED
@@ -17,18 +17,42 @@ It has been trained using [TRL](https://github.com/huggingface/trl).
 ## Quick start
 
 ```python
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
-generator = pipeline("text-generation", model="oriental-lab/TinySwallow-1.5B-dolly", device="cuda")
-output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
-print(output["generated_text"])
+model_name = "oriental-lab/TinySwallow-1.5B-dolly"
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+prompt = "Do plants have life?"
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": prompt}
+]
+text = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True
+)
+model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+generated_ids = model.generate(
+    **model_inputs,
+    max_new_tokens=512
+)
+generated_ids = [
+    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+]
+
+response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 ```
 
 ## Training procedure
 
-[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/twin1shun/huggingface/runs/wbfofu5p)
-
  This model was trained with SFT.
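The `apply_chat_template` call in the updated Quick start renders the message list into the model's prompt format. As a rough illustration only (the authoritative template ships with the tokenizer config; the ChatML-style markers below are an assumption based on TinySwallow's Qwen2.5 lineage, and `render_chatml` is a made-up helper):

```python
# Illustrative sketch only: hand-rolls a ChatML-style prompt string. In practice,
# always use tokenizer.apply_chat_template, which reads the real template.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Do plants have life?"},
]

def render_chatml(messages):
    # Each turn becomes <|im_start|>role\ncontent<|im_end|>, then an open
    # assistant turn is appended (the add_generation_prompt=True part).
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)

text = render_chatml(messages)
print(text)
```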
58
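The list comprehension in the updated snippet trims the prompt tokens off each generated sequence before decoding, so `batch_decode` returns only the model's answer rather than the echoed prompt. A minimal sketch of the same logic, with plain lists standing in for tensors (the token ids below are made up):

```python
# Plain-list stand-ins for model_inputs.input_ids and the model.generate() output.
input_ids = [[101, 42, 43, 102]]               # one prompt of 4 token ids
generated_ids = [[101, 42, 43, 102, 7, 8, 9]]  # generate() echoes the prompt first

# Same pattern as the README: keep only tokens produced after the prompt.
new_token_ids = [
    output_seq[len(input_seq):]
    for input_seq, output_seq in zip(input_ids, generated_ids)
]

print(new_token_ids)  # [[7, 8, 9]]
```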