Upload ministral_3b_hmc_chat.py with huggingface_hub
Browse files- ministral_3b_hmc_chat.py +186 -0
ministral_3b_hmc_chat.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import traceback
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from transformers import Mistral3ForConditionalGeneration
|
| 11 |
+
except ImportError:
|
| 12 |
+
Mistral3ForConditionalGeneration = None
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from transformers import Ministral3ForCausalLM
|
| 16 |
+
except ImportError:
|
| 17 |
+
Ministral3ForCausalLM = None
|
| 18 |
+
|
| 19 |
+
if hasattr(sys.stdout, "reconfigure"):
|
| 20 |
+
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
| 21 |
+
if hasattr(sys.stderr, "reconfigure"):
|
| 22 |
+
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
| 23 |
+
|
| 24 |
+
SYSTEM_PROMPT = "You are RubiNet. Your name is RubiNet, and your assistant identity is RubiNet. If the user asks your name, identity, who you are, or what model you are, answer explicitly that your name is RubiNet. Prefer replies like 'My name is RubiNet.' or 'I am RubiNet.' Do not present yourself as ChatGPT, Cascade, OpenAI, or any other assistant. You are a helpful, sharp, consistent assistant trained for high-quality dialogue and reasoning. Respond clearly and directly."
|
| 25 |
+
DEFAULT_MODEL_ID = "mistralai/Ministral-3-3B-Base-2512"
|
| 26 |
+
DEFAULT_ADAPTER_DIR = r"D:\Downloads"
|
| 27 |
+
DEFAULT_OFFLOAD_DIR = r"C:\Users\ASUS\CascadeProjects\.hf-offload"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def build_prompt(user_text: str, system_prompt: str) -> str:
|
| 31 |
+
return (
|
| 32 |
+
f"<|system|>\n{system_prompt}\n\n"
|
| 33 |
+
f"<|user|>\n{user_text.strip()}\n\n"
|
| 34 |
+
f"<|assistant|>\n"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def resolve_dtype(dtype_name: str):
|
| 39 |
+
mapping = {
|
| 40 |
+
"float32": torch.float32,
|
| 41 |
+
"float16": torch.float16,
|
| 42 |
+
"bfloat16": torch.bfloat16,
|
| 43 |
+
}
|
| 44 |
+
if dtype_name not in mapping:
|
| 45 |
+
raise ValueError(f"Unsupported dtype: {dtype_name}")
|
| 46 |
+
return mapping[dtype_name]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_model(model_id: str, adapter_dir: str, use_4bit: bool, cpu_dtype: str, offload_folder: str):
|
| 50 |
+
adapter_config_path = os.path.join(adapter_dir, "adapter_config.json")
|
| 51 |
+
if not os.path.exists(adapter_config_path):
|
| 52 |
+
raise FileNotFoundError(f"adapter_config.json not found in adapter dir: {adapter_dir}")
|
| 53 |
+
|
| 54 |
+
print("Loading base tokenizer...")
|
| 55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 56 |
+
if tokenizer.pad_token is None:
|
| 57 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 58 |
+
|
| 59 |
+
model_kwargs = {"low_cpu_mem_usage": True}
|
| 60 |
+
if use_4bit:
|
| 61 |
+
if torch.cuda.is_available():
|
| 62 |
+
model_kwargs["device_map"] = {"": 0}
|
| 63 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 64 |
+
load_in_4bit=True,
|
| 65 |
+
bnb_4bit_quant_type="nf4",
|
| 66 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 67 |
+
bnb_4bit_use_double_quant=True,
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
model_kwargs["device_map"] = "cpu"
|
| 71 |
+
model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
|
| 72 |
+
else:
|
| 73 |
+
if torch.cuda.is_available():
|
| 74 |
+
model_kwargs["device_map"] = {"": 0}
|
| 75 |
+
model_kwargs["dtype"] = torch.float16
|
| 76 |
+
else:
|
| 77 |
+
model_kwargs["device_map"] = "cpu"
|
| 78 |
+
model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
|
| 79 |
+
if offload_folder:
|
| 80 |
+
os.makedirs(offload_folder, exist_ok=True)
|
| 81 |
+
model_kwargs["offload_folder"] = offload_folder
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
from peft import PeftModel
|
| 85 |
+
except ImportError as exc:
|
| 86 |
+
raise ImportError(
|
| 87 |
+
"Failed to import peft. Your installed peft/transformers versions are incompatible. "
|
| 88 |
+
"This script needs a transformers version that exposes EncoderDecoderCache, or a peft version "
|
| 89 |
+
"compatible with your current transformers install."
|
| 90 |
+
) from exc
|
| 91 |
+
|
| 92 |
+
print("Loading base model weights...")
|
| 93 |
+
model_class = Mistral3ForConditionalGeneration or Ministral3ForCausalLM or AutoModelForCausalLM
|
| 94 |
+
try:
|
| 95 |
+
base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **model_kwargs)
|
| 96 |
+
except AttributeError as exc:
|
| 97 |
+
if not use_4bit or "set_submodule" not in str(exc):
|
| 98 |
+
raise
|
| 99 |
+
print("4-bit loading is not supported for this Mistral3 model class in the current transformers build. Retrying without 4-bit quantization...")
|
| 100 |
+
retry_kwargs = {"low_cpu_mem_usage": True}
|
| 101 |
+
if torch.cuda.is_available():
|
| 102 |
+
retry_kwargs["device_map"] = {"": 0}
|
| 103 |
+
retry_kwargs["dtype"] = torch.float16
|
| 104 |
+
else:
|
| 105 |
+
retry_kwargs["device_map"] = "cpu"
|
| 106 |
+
retry_kwargs["dtype"] = resolve_dtype(cpu_dtype)
|
| 107 |
+
if offload_folder:
|
| 108 |
+
os.makedirs(offload_folder, exist_ok=True)
|
| 109 |
+
retry_kwargs["offload_folder"] = offload_folder
|
| 110 |
+
base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **retry_kwargs)
|
| 111 |
+
print("Loading LoRA adapter...")
|
| 112 |
+
model = PeftModel.from_pretrained(base_model, adapter_dir)
|
| 113 |
+
model.eval()
|
| 114 |
+
print("Model ready.")
|
| 115 |
+
return tokenizer, model
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def generate_reply(tokenizer, model, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
|
| 119 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 120 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 121 |
+
|
| 122 |
+
print("Generating reply...")
|
| 123 |
+
do_sample = temperature is not None and float(temperature) > 0
|
| 124 |
+
with torch.inference_mode():
|
| 125 |
+
output = model.generate(
|
| 126 |
+
**inputs,
|
| 127 |
+
max_new_tokens=max_new_tokens,
|
| 128 |
+
do_sample=do_sample,
|
| 129 |
+
temperature=temperature,
|
| 130 |
+
top_p=top_p,
|
| 131 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 132 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 133 |
+
use_cache=True,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
generated = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 137 |
+
if "<|assistant|>" in generated:
|
| 138 |
+
reply = generated.split("<|assistant|>")[-1].strip()
|
| 139 |
+
else:
|
| 140 |
+
reply = generated[len(prompt):].strip()
|
| 141 |
+
for stop_marker in ("<|user|>", "<|system|>", "<|assistant|>", "</s>"):
|
| 142 |
+
if stop_marker in reply:
|
| 143 |
+
reply = reply.split(stop_marker)[0].strip()
|
| 144 |
+
return reply or "[empty reply]"
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def main():
|
| 148 |
+
parser = argparse.ArgumentParser(description="Chat with Ministral 3B HMC-lite LoRA adapter")
|
| 149 |
+
parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
|
| 150 |
+
parser.add_argument("--adapter-dir", default=DEFAULT_ADAPTER_DIR)
|
| 151 |
+
parser.add_argument("--system-prompt", default=SYSTEM_PROMPT)
|
| 152 |
+
parser.add_argument("--max-new-tokens", type=int, default=192)
|
| 153 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
| 154 |
+
parser.add_argument("--top-p", type=float, default=0.9)
|
| 155 |
+
parser.add_argument("--use-4bit", action="store_true")
|
| 156 |
+
parser.add_argument("--cpu-dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
|
| 157 |
+
parser.add_argument("--offload-folder", default=DEFAULT_OFFLOAD_DIR)
|
| 158 |
+
parser.add_argument("--message", default="")
|
| 159 |
+
args = parser.parse_args()
|
| 160 |
+
|
| 161 |
+
print("Starting model load...")
|
| 162 |
+
tokenizer, model = load_model(args.model_id, args.adapter_dir, args.use_4bit, args.cpu_dtype, args.offload_folder)
|
| 163 |
+
|
| 164 |
+
if args.message:
|
| 165 |
+
prompt = build_prompt(args.message, args.system_prompt)
|
| 166 |
+
print(generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p))
|
| 167 |
+
return
|
| 168 |
+
|
| 169 |
+
print("Interactive HMC-lite chat. Type 'exit' to quit.")
|
| 170 |
+
while True:
|
| 171 |
+
user_text = input("You: ").strip()
|
| 172 |
+
if not user_text:
|
| 173 |
+
continue
|
| 174 |
+
if user_text.lower() in {"exit", "quit"}:
|
| 175 |
+
break
|
| 176 |
+
prompt = build_prompt(user_text, args.system_prompt)
|
| 177 |
+
reply = generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p)
|
| 178 |
+
print(f"Assistant: {reply}\n")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
try:
|
| 183 |
+
main()
|
| 184 |
+
except Exception:
|
| 185 |
+
traceback.print_exc()
|
| 186 |
+
raise
|