DevaFlow / config.py
bhsinghgrid's picture
Add files using upload-large-folder tool
7d6a683 verified
raw
history blame contribute delete
739 Bytes
import torch
# Central configuration for the DevaFlow "d3pm_cross_attention" model.
# Sections, in order: data, diffusion, model, training, inference.
CONFIG = dict(
    model_type="d3pm_cross_attention",
    # Dataset options.
    data=dict(
        include_negative_examples=True,
        dataset_size=60_000,
    ),
    # Diffusion-process settings.
    diffusion=dict(
        # NOTE(review): presumably the id of the mask token; confirm it does
        # not collide with a padding/special token id in the tokenizer.
        mask_token_id=0,
    ),
    # Network architecture hyperparameters.
    model=dict(
        src_vocab_size=16_000,
        tgt_vocab_size=16_000,
        d_model=384,
        n_heads=8,  # d_model / n_heads = 48
        d_ff=1_536,  # 4 * d_model
        n_layers=6,
        dropout=0.1,
        max_seq_len=80,
        diffusion_steps=64,
    ),
    # Device string, evaluated once at import time: "cuda" when a CUDA
    # device is available, otherwise "cpu".
    training=dict(
        device="cuda" if torch.cuda.is_available() else "cpu",
    ),
    # Inference-time settings.
    inference=dict(
        num_steps=64,  # same value as model["diffusion_steps"]
        temperature=0.7,
        top_k=40,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ),
)