File size: 739 Bytes
7d6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch

# Run configuration for the "d3pm_cross_attention" model, grouped by concern.
# Built from dict() keyword calls: every key is a valid identifier, so the
# resulting mapping (keys, values, and insertion order) is the same as the
# equivalent nested dict literal.
CONFIG = dict(
    model_type="d3pm_cross_attention",
    # Dataset options.
    data=dict(
        include_negative_examples=True,
        dataset_size=60000,
    ),
    # Diffusion-process options.
    # NOTE(review): mask token id is 0 -- confirm it does not collide with a
    # padding or other special token id in the tokenizer.
    diffusion=dict(
        mask_token_id=0,
    ),
    # Network hyperparameters.
    model=dict(
        src_vocab_size=16000,
        tgt_vocab_size=16000,
        d_model=384,
        n_heads=8,
        d_ff=1536,  # == 4 * d_model
        n_layers=6,
        dropout=0.1,
        max_seq_len=80,
        diffusion_steps=64,
    ),
    # Training runtime. The device string is resolved once, at import time:
    # "cuda" when torch reports CUDA as available, otherwise "cpu".
    training=dict(
        device="cuda" if torch.cuda.is_available() else "cpu",
    ),
    # Sampling options.
    # NOTE(review): num_steps currently equals model.diffusion_steps (64);
    # presumably it should not exceed it -- verify against the sampler.
    inference=dict(
        num_steps=64,
        temperature=0.7,
        top_k=40,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ),
)