# WorldDisasterLM-8B / scripts/write_hf_configs.py
# Commit 4b93901 (drdeveloper88): Sync: correct languages
# (en/ne/es/fr/ar/hi/te/zh/ja/ko/pt), updated README, full source code
"""Write best-practice HuggingFace model config files to the given output directory."""
import json
import os
import sys
# Destination directory: the first command-line argument when one is given,
# otherwise the current working directory.
_cli_args = sys.argv[1:]
out_dir = _cli_args[0] if _cli_args else "."
# ── 1. config.json ──────────────────────────────────────────────────────────
# Model settings for config.json. The "quantization_config" entry is attached
# below, after the standalone quantization dict is defined, so the
# BitsAndBytes settings are spelled out in exactly one place.
config = {
    "_name_or_path": "drdeveloper88/WorldDisasterLM-8B",
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": [128001, 128008, 128009],
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_interleaved": False,
    "rope_scaling": {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.43.0",
    "use_cache": True,
    "vocab_size": 128256,
}

# ── 2. quantization_config.json (standalone BitsAndBytes NF4) ───────────────
# Single source of truth for the quantization settings: written on its own as
# quantization_config.json and also embedded in config.json.
quantization_config = {
    "quant_method": "bitsandbytes",
    "load_in_4bit": True,
    "load_in_8bit": False,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_storage": "uint8",
    "llm_int8_threshold": 6.0,
    "llm_int8_skip_modules": None,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
}

# Embed a shallow copy (not an alias) as the last key of config.json, which is
# where the entry sits in the emitted file; the copy keeps the two dicts
# independent if either is modified later.
config["quantization_config"] = dict(quantization_config)
# ── 3. adapter_config.json (full PEFT LoRA) ──────────────────────────────────
# LoRA adapter description. Keyword order is the key order of the emitted JSON.
# NOTE(review): the last three entries are summary statistics stored as
# pre-formatted strings; presumably informational only — confirm that whatever
# reads adapter_config.json tolerates keys it does not recognise.
adapter_config = dict(
    _version="0.7.1",
    alpha_pattern={},
    auto_mapping=None,
    base_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
    bias="none",
    fan_in_fan_out=False,
    inference_mode=True,
    init_lora_weights=True,
    layer_replication=None,
    loftq_config={},
    lora_alpha=32,
    lora_dropout=0.05,
    modules_to_save=None,
    peft_type="LORA",
    r=16,
    rank_pattern={},
    revision=None,
    # Names of the projection modules the adapter is applied to.
    target_modules="q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split(),
    task_type="CAUSAL_LM",
    use_dora=False,
    use_rslora=False,
    trainable_parameters="41,943,040",
    total_parameters="8,030,261,248",
    trainable_pct="0.52%",
)
# ── 4. tokenizer_config.json with disaster-domain system prompt ──────────────
# Default system prompt, used when a conversation brings no system message.
SYSTEM_PROMPT = (
    "You are WorldDisasterLM-8B, an expert AI specialized in global disaster "
    "management, emergency response, and humanitarian aid. You provide accurate, "
    "actionable guidance in 11 languages including Nepali, Hindi, Arabic, French, "
    "Spanish, Swahili, Indonesian, Portuguese, Chinese, and Bengali. "
    "Always prioritize life safety. Cite authoritative sources (NDRRMA for Nepal, "
    "WHO, FEMA, GDACS, USGS) when relevant. Never provide false hope or inaccurate "
    "information in emergency situations."
)

# Jinja2 chat template with the disaster system prompt hardcoded as the default.
# It always emits one system turn, then every remaining message, then (when
# add_generation_prompt is set) an opening assistant header.
CHAT_TEMPLATE = (
    # The prompt is embedded in a double-quoted Jinja string literal, so any
    # double quote inside it is swapped for a single quote to keep it closed.
    "{%- set default_system = \"" + SYSTEM_PROMPT.replace('"', "'") + "\" %}"
    # A leading system message replaces the default and is dropped from the
    # loop below. `messages and` guards the index: without it an empty
    # conversation fails on messages[0] instead of rendering the default prompt.
    "{%- if messages and messages[0]['role'] == 'system' %}"
    "{%- set default_system = messages[0]['content'] %}"
    "{%- set messages = messages[1:] %}"
    "{%- endif %}"
    "{{ bos_token }}"
    "<|start_header_id|>system<|end_header_id|>\n\n{{ default_system }}<|eot_id|>"
    "{%- for message in messages %}"
    "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"
    "{%- endfor %}"
    "{%- if add_generation_prompt %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
    "{%- endif %}"
)

# Payload for tokenizer_config.json; carries the chat template built above.
tokenizer_config = {
    "add_bos_token": True,
    "add_eos_token": False,
    "add_prefix_space": None,
    "bos_token": "<|begin_of_text|>",
    "chat_template": CHAT_TEMPLATE,
    "clean_up_tokenization_spaces": True,
    "eos_token": "<|eot_id|>",
    "model_max_length": 131072,
    "pad_token": "<|end_of_text|>",
    "padding_side": "right",
    "tokenizer_class": "PreTrainedTokenizerFast",
    "unk_token": None,
}
# ── 5. generation_config.json ────────────────────────────────────────────────
# Default text-generation settings. Keyword order is the key order of the
# emitted JSON.
generation_config = dict(
    _from_model_config=False,
    bos_token_id=128000,
    do_sample=True,
    eos_token_id=[128001, 128008, 128009],
    # Length bounds for the generated continuation.
    max_new_tokens=512,
    min_new_tokens=10,
    # Sampling controls.
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    # Repetition controls.
    repetition_penalty=1.1,
    no_repeat_ngram_size=3,
    transformers_version="4.43.0",
)
# ── 6. special_tokens_map.json ───────────────────────────────────────────────
def _special_token(content):
    """Return the map entry for *content* with all four matching flags off."""
    return {
        "content": content,
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    }


# bos/eos/pad entries share the same flag set and differ only in token text.
special_tokens_map = {
    role + "_token": _special_token(text)
    for role, text in (
        ("bos", "<|begin_of_text|>"),
        ("eos", "<|eot_id|>"),
        ("pad", "<|end_of_text|>"),
    )
}
# Appended last so it stays the final key of the emitted JSON.
special_tokens_map["additional_special_tokens"] = [
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eot_id|>",
    "<|begin_of_text|>",
    "<|end_of_text|>",
]
# ── 7. training_args.json — QLoRA hyperparameters ────────────────────────────
# The payload is assembled from thematic groups; merging them in this order
# reproduces the key order of the emitted JSON.
_run_setup = {
    "model_name_or_path": "meta-llama/Llama-3.1-8B-Instruct",
    "output_dir": "./outputs/WorldDisasterLM-8B",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,
}
_optimizer = {
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.03,
    "weight_decay": 0.001,
    "max_grad_norm": 0.3,
    "optim": "paged_adamw_32bit",
}
_precision_and_sequence = {
    "fp16": False,
    "bf16": True,
    "max_seq_length": 4096,
    "packing": True,
}
_lora = {
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "lora_target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
}
_quantization = {
    "use_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
    "use_nested_quant": True,
}
_checkpoint_and_eval = {
    "save_steps": 100,
    "logging_steps": 25,
    "evaluation_strategy": "steps",
    "eval_steps": 100,
    "save_total_limit": 3,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
}
_runtime = {
    "dataloader_num_workers": 4,
    "seed": 42,
    "report_to": ["tensorboard"],
}
_provenance = {
    "dataset_sources": ["ReliefWeb", "USGS", "GDACS", "NOAA", "OpenFEMA", "WHO"],
    "languages": ["en", "ne", "hi", "ar", "fr", "es", "sw", "id", "pt", "zh", "bn"],
    "training_status": "PENDING — weights not yet generated. Run train.py to produce weights.",
}

training_args = {
    **_run_setup,
    **_optimizer,
    **_precision_and_sequence,
    **_lora,
    **_quantization,
    **_checkpoint_and_eval,
    **_runtime,
    **_provenance,
}
# ── Write all files ───────────────────────────────────────────────────────────
# Output file name -> JSON-serialisable payload; files are written in this order.
files = {
    "config.json": config,
    "quantization_config.json": quantization_config,
    "adapter_config.json": adapter_config,
    "tokenizer_config.json": tokenizer_config,
    "generation_config.json": generation_config,
    "special_tokens_map.json": special_tokens_map,
    "training_args.json": training_args,
}

# Create the destination up front so a directory that does not exist yet no
# longer makes open() fail with FileNotFoundError. `or "."` keeps an empty
# out_dir meaning "current directory", as os.path.join already treats it.
os.makedirs(out_dir or ".", exist_ok=True)

for fname, data in files.items():
    path = os.path.join(out_dir, fname)
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
        json.dump(data, f, indent=2, ensure_ascii=False)
    # Measured after the file is closed so the size reflects the flushed content.
    size = os.path.getsize(path)
    print(f" {fname:35s} {size:>6} bytes")
print(f"\nAll {len(files)} config files written to: {os.path.abspath(out_dir)}")