Asanaly committed on
Commit
6de6b50
·
verified ·
1 Parent(s): 72e1647

Update summarizer.py

Browse files
Files changed (1) hide show
  1. summarizer.py +7 -20
summarizer.py CHANGED
@@ -1,29 +1,16 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BartTokenizer, BartForConditionalGeneration
2
 
3
- # Английская модель
4
- eng_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
5
- eng_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
6
 
7
- # Русская модель (публичная)
8
- rus_tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")
9
- rus_model = AutoModelForSeq2SeqLM.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")
10
-
11
- def generate_summary(text: str, lang="en", max_length=200, min_length=50) -> str:
12
- if lang == "ru":
13
- tokenizer = rus_tokenizer
14
- model = rus_model
15
- else:
16
- tokenizer = eng_tokenizer
17
- model = eng_model
18
-
19
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
20
 
21
  summary_ids = model.generate(
22
  inputs["input_ids"],
23
- max_length=max_length,
24
- min_length=min_length,
25
- length_penalty=2.0,
26
  num_beams=4,
 
 
27
  early_stopping=True
28
  )
29
 
 
1
+ from transformers import BartTokenizer, BartForConditionalGeneration
2
 
3
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
4
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
 
5
 
6
+ def generate_summary(text: str) -> str:
7
+ inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  summary_ids = model.generate(
10
  inputs["input_ids"],
 
 
 
11
  num_beams=4,
12
+ min_length=40,
13
+ max_length=200,
14
  early_stopping=True
15
  )
16