nluai commited on
Commit
94c6410
·
verified ·
1 Parent(s): d115d6c

Upload 6 files

Browse files
README.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Model description
2
+ This model is a sequence-to-sequence question generator that takes an answer and context as an input and generates a question as an output. It is based on a pre-trained mt5-base by [Google](https://github.com/google-research/multilingual-t5) model.
3
+
4
+ ## Training data
5
+ The model was fine-tuned on [XQuAD](https://github.com/deepmind/xquad)
6
+
7
+ ## Example usage
8
+ ```python
9
+ from transformers import MT5ForConditionalGeneration, AutoTokenizer
10
+ import torch
11
+
12
+ model = MT5ForConditionalGeneration.from_pretrained("nluai/question-generation-vietnamese")
13
+ tokenizer = AutoTokenizer.from_pretrained("nluai/question-generation-vietnamese")
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model = model.to(device)
17
+
18
+ # Content used to create a set of questions
19
+ context = '''Thành phố Hồ Chí Minh (còn gọi là Sài Gòn) tên gọi cũ trước 1975 là Sài Gòn hay Sài Gòn-Gia Định là thành phố lớn nhất ở Việt Nam về dân số và quy mô đô thị hóa. Đây còn là trung tâm kinh tế, chính trị, văn hóa và giáo dục tại Việt Nam. Thành phố Hồ Chí Minh là thành phố trực thuộc trung ương thuộc loại đô thị đặc biệt của Việt Nam cùng với thủ đô Hà Nội.Nằm trong vùng chuyển tiếp giữa Đông Nam Bộ và Tây Nam Bộ, thành phố này hiện có 16 quận, 1 thành phố và 5 huyện, tổng diện tích 2.061 km². Theo kết quả điều tra dân số chính thức vào thời điểm ngày một tháng 4 năm 2009 thì dân số thành phố là 7.162.864 người (chiếm 8,34% dân số Việt Nam), mật độ dân số trung bình 3.419 người/km². Đến năm 2019, dân số thành phố tăng lên 8.993.082 người và cũng là nơi có mật độ dân số cao nhất Việt Nam. Tuy nhiên, nếu tính những người cư trú không đăng ký hộ khẩu thì dân số thực tế của thành phố này năm 2018 là gần 14 triệu người.'''
20
+
21
+ encoding = tokenizer.encode_plus(context, return_tensors="pt")
22
+
23
+ input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
24
+
25
+ output = model.generate(input_ids=input_ids, attention_mask=attention_masks, max_length=256)
26
+
27
+ question = tokenizer.decode(output[0], skip_special_tokens=True,clean_up_tokenization_spaces=True)
28
+
29
+ question
30
+ #question: Thành phố hồ chí minh có bao nhiêu quận?
31
+ ```
32
+
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/mt5-base",
3
+ "architectures": ["MT5ForConditionalGeneration"],
4
+ "d_ff": 2048,
5
+ "d_kv": 64,
6
+ "d_model": 768,
7
+ "decoder_start_token_id": 0,
8
+ "dropout_rate": 0.1,
9
+ "eos_token_id": 1,
10
+ "feed_forward_proj": "gated-gelu",
11
+ "initializer_factor": 1,
12
+ "is_encoder_decoder": true,
13
+ "layer_norm_epsilon": 0.000001,
14
+ "model_type": "mt5",
15
+ "num_decoder_layers": 12,
16
+ "num_heads": 12,
17
+ "num_layers": 12,
18
+ "output_past": true,
19
+ "pad_token_id": 0,
20
+ "relative_attention_num_buckets": 32,
21
+ "tie_word_embeddings": false,
22
+ "tokenizer_class": "T5Tokenizer",
23
+ "vocab_size": 250112
24
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bd8f585f240e065d12a018fe234b64d49007dd59c51b6af068f386da4f76962
3
+ size 2329707353
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
3
+ size 4309802
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "/home/patrick/.cache/torch/transformers/685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "name_or_path": "google/mt5-base"}