sususupa commited on
Commit
d4c15a1
·
verified ·
1 Parent(s): ce77ba5

Create summary_ko

Browse files
Files changed (1) hide show
  1. summary_ko +186 -0
summary_ko ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. 개발 환경 설정¶
2
+ # 1.1 필수 라이브러리 설치하기¶
3
+ In [ ]:
4
+ !pip3 install -q -U transformers==4.38.2
5
+ !pip3 install -q -U datasets==2.18.0
6
+ !pip3 install -q -U bitsandbytes==0.42.0
7
+ !pip3 install -q -U peft==0.9.0
8
+ !pip3 install -q -U trl==0.7.11
9
+ !pip3 install -q -U accelerate==0.27.2
10
+
11
+ # 1.2 Import modules¶
12
+ In [ ]:
13
+ import torch
14
+ from datasets import Dataset, load_dataset
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments
16
+ from peft import LoraConfig, PeftModel
17
+ from trl import SFTTrainer
18
+
19
+ # 1.3 Huggingface 로그인¶
20
+ In [ ]:
21
+ from huggingface_hub import notebook_login
22
+ notebook_login()
23
+
24
+ # 2. Dataset 생성 및 준비¶
25
+ # 2.1 데이터셋 로드¶
26
+ In [ ]:
27
+ from datasets import load_dataset
28
+ dataset = load_dataset("daekeun-ml/naver-news-summarization-ko")
29
+ # 2.2 데이터셋 탐색¶
30
+ In [ ]:
31
+ dataset
32
+ # 2.3 데이터셋 예시¶
33
+ In [ ]:
34
+ dataset['train'][0]
35
+
36
+ # 3. Gemma 모델의 한국어 요약 테스트¶
37
+ # 3.1 모델 로드¶
38
+ In [ ]:
39
+ BASE_MODEL = "google/gemma-2b-it"
40
+
41
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map={"":0})
42
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)
43
+ # 3.2 Gemma-it의 프롬프트 형식¶
44
+ In [ ]:
45
+ doc = dataset['train']['document'][0]
46
+ In [ ]:
47
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
48
+ In [ ]:
49
+ messages = [
50
+ {
51
+ "role": "user",
52
+ "content": "다음 글을 요약해주세요 :\n\n{}".format(doc)
53
+ }
54
+ ]
55
+ prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
56
+ In [ ]:
57
+ prompt
58
+ # 3.3 Gemma-it 추론¶
59
+ In [ ]:
60
+ outputs = pipe(
61
+ prompt,
62
+ do_sample=True,
63
+ temperature=0.2,
64
+ top_k=50,
65
+ top_p=0.95,
66
+ add_special_tokens=True
67
+ )
68
+ In [ ]:
69
+ print(outputs[0]["generated_text"][len(prompt):])
70
+
71
+ # 4. Gemma 파인튜닝¶
72
+ 주의: Colab GPU 메모리 한계로 이전장 추론에서 사용했던 메모리를 비워 줘야 파인튜닝을 진행 할 수 있습니다.
73
+ notebook 런타임 세션을 재시작 한 후 1번과 2번의 2.1 항목까지 다시 실행하여 로드 한 후 아래 과정을 진행합니다
74
+ In [ ]:
75
+ !nvidia-smi
76
+ # 4.1 학습용 프롬프트 조정¶
77
+ In [ ]:
78
+ def generate_prompt(example):
79
+ prompt_list = []
80
+ for i in range(len(example['document'])):
81
+ prompt_list.append(r"""<bos><start_of_turn>user
82
+ 다음 글을 요약해주세요:
83
+
84
+ {}<end_of_turn>
85
+ <start_of_turn>model
86
+ {}<end_of_turn><eos>""".format(example['document'][i], example['summary'][i]))
87
+ return prompt_list
88
+ In [ ]:
89
+ train_data = dataset['train']
90
+ print(generate_prompt(train_data[:1])[0])
91
+ # 4.2 QLoRA 설정¶
92
+ In [ ]:
93
+ lora_config = LoraConfig(
94
+ r=6,
95
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
96
+ task_type="CAUSAL_LM",
97
+ )
98
+
99
+ bnb_config = BitsAndBytesConfig(
100
+ load_in_4bit=True,
101
+ bnb_4bit_quant_type="nf4",
102
+ bnb_4bit_compute_dtype=torch.float16
103
+ )
104
+ In [ ]:
105
+ BASE_MODEL = "google/gemma-2b-it"
106
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", quantization_config=bnb_config)
107
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)
108
+ tokenizer.padding_side = 'right'
109
+ # 4.3 Trainer 실행¶
110
+ In [ ]:
111
+ trainer = SFTTrainer(
112
+ model=model,
113
+ train_dataset=train_data,
114
+ max_seq_length=512,
115
+ args=TrainingArguments(
116
+ output_dir="outputs",
117
+ # num_train_epochs = 1,
118
+ max_steps=3000,
119
+ per_device_train_batch_size=1,
120
+ gradient_accumulation_steps=4,
121
+ optim="paged_adamw_8bit",
122
+ warmup_steps=0.03,
123
+ learning_rate=2e-4,
124
+ fp16=True,
125
+ logging_steps=100,
126
+ push_to_hub=False,
127
+ report_to='none',
128
+ ),
129
+ peft_config=lora_config,
130
+ formatting_func=generate_prompt,
131
+ )
132
+ In [ ]:
133
+ trainer.train()
134
+ # 4.4 Finetuned Model 저장¶
135
+ In [ ]:
136
+ ADAPTER_MODEL = "lora_adapter"
137
+
138
+ trainer.model.save_pretrained(ADAPTER_MODEL)
139
+ In [ ]:
140
+ !ls -alh lora_adapter
141
+ In [ ]:
142
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.float16)
143
+ model = PeftModel.from_pretrained(model, ADAPTER_MODEL, device_map='auto', torch_dtype=torch.float16)
144
+
145
+ model = model.merge_and_unload()
146
+ model.save_pretrained('gemma-2b-it-sum-ko')
147
+ In [ ]:
148
+ !ls -alh ./gemma-2b-it-sum-ko
149
+
150
+ # 5. Gemma 한국어 요약 모델 추론¶
151
+ 주의: 마찬가지로 Colab GPU 메모리 한계로 학습 시 사용했던 메모리를 비워 줘야 파인튜닝을 진행 할 수 있습니다.
152
+ notebook 런타임 세션을 재시작 한 후 1번과 2번의 2.1 항목까지 다시 실행하여 로드 한 후 아래 과정을 진행합니다
153
+ In [ ]:
154
+ !nvidia-smi
155
+ # 5.1 Fine-tuned 모델 로드¶
156
+ In [ ]:
157
+ BASE_MODEL = "google/gemma-2b-it"
158
+ FINETUNE_MODEL = "./gemma-2b-it-sum-ko"
159
+
160
+ finetune_model = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL, device_map={"":0})
161
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, add_special_tokens=True)
162
+ # 5.2 Fine-tuned 모델 추론¶
163
+ In [ ]:
164
+ pipe_finetuned = pipeline("text-generation", model=finetune_model, tokenizer=tokenizer, max_new_tokens=512)
165
+ In [ ]:
166
+ doc = dataset['test']['document'][10]
167
+ In [ ]:
168
+ messages = [
169
+ {
170
+ "role": "user",
171
+ "content": "다음 글을 요약해주세요:\n\n{}".format(doc)
172
+ }
173
+ ]
174
+ prompt = pipe_finetuned.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
175
+ In [ ]:
176
+ outputs = pipe_finetuned(
177
+ prompt,
178
+ do_sample=True,
179
+ temperature=0.2,
180
+ top_k=50,
181
+ top_p=0.95,
182
+ add_special_tokens=True
183
+ )
184
+ print(outputs[0]["generated_text"][len(prompt):])
185
+ In [ ]:
186
+