Leesn465 committed on
Commit
5a95741
·
verified ·
1 Parent(s): 1097a00

Update keyword_module.py

Browse files
Files changed (1) hide show
  1. keyword_module.py +12 -7
keyword_module.py CHANGED
@@ -11,17 +11,22 @@ from bs4 import BeautifulSoup as bs
11
# Pretrained KoBART summarization pair (Hugging Face hub id "gogamza/kobart-summarization").
# Loaded once at module import so every summarize_kobart() call reuses the same
# tokenizer/model instances instead of re-downloading or re-initializing them.
summary_tokenizer = PreTrainedTokenizerFast.from_pretrained("gogamza/kobart-summarization")
summary_model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-summarization")
13
 
14
def summarize_kobart(text, max_input_length=512):
    """Summarize Korean text with the KoBART summarization model.

    Parameters
    ----------
    text : str
        Input document to summarize.
    max_input_length : int, optional
        Encoder token budget; longer input is truncated to fit.

    Returns
    -------
    str
        The decoded summary with special tokens stripped.
    """
    # Truncate the input so it fits within the encoder's token limit.
    input_ids = summary_tokenizer.encode(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    # Beam-search decoding; the repetition penalty and n-gram block curb
    # the model's tendency to loop on the same phrase.
    generated = summary_model.generate(
        input_ids,
        max_length=160,
        min_length=100,
        num_beams=4,
        repetition_penalty=2.5,
        no_repeat_ngram_size=3,
        early_stopping=True,
    )
    return summary_tokenizer.decode(generated[0], skip_special_tokens=True)
 
11
# Pretrained KoBART summarization pair (Hugging Face hub id "gogamza/kobart-summarization").
# Loaded once at module import so every summarize_kobart() call reuses the same
# tokenizer/model instances instead of re-downloading or re-initializing them.
summary_tokenizer = PreTrainedTokenizerFast.from_pretrained("gogamza/kobart-summarization")
summary_model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-summarization")
13
 
14
def summarize_kobart(text, max_input_length=512):
    """Summarize Korean text with the KoBART summarization model.

    Parameters
    ----------
    text : str
        Input document to summarize.
    max_input_length : int, optional
        Encoder token budget; longer input is truncated. Kept as a
        parameter (default 512, matching the previous hard-coded value)
        for backward compatibility with existing callers.
        NOTE(review): KoBART's positional limit is presumably 512 or
        1024 — confirm against the model config.

    Returns
    -------
    str
        The decoded summary with special tokens stripped.
    """
    # Truncate the input so it fits the encoder. Calling the tokenizer
    # directly (instead of .encode()) also returns the attention mask,
    # which generate() can then use.
    inputs = summary_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    summary_ids = summary_model.generate(
        **inputs,
        max_new_tokens=160,  # bound OUTPUT length via *_new_tokens so it
        min_new_tokens=100,  # is independent of the input length
        num_beams=4,
        repetition_penalty=2.5,
        no_repeat_ngram_size=4,
        early_stopping=True,
    )
    return summary_tokenizer.decode(summary_ids[0], skip_special_tokens=True)