"""RubiNet / ministral_3b_hmc_chat.py -- chat with the Ministral 3B HMC-lite LoRA adapter."""
import argparse
import os
import sys
import traceback
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Newer transformers builds expose dedicated Mistral 3 / Ministral 3 classes; fall back to None when absent.
try:
    from transformers import Mistral3ForConditionalGeneration
except ImportError:
    Mistral3ForConditionalGeneration = None
try:
    from transformers import Ministral3ForCausalLM
except ImportError:
    Ministral3ForCausalLM = None
# Reconfigure stdio to UTF-8 so non-ASCII model output does not crash on consoles with legacy encodings.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")
SYSTEM_PROMPT = (
    "You are RubiNet. Your name is RubiNet, and your assistant identity is RubiNet. "
    "If the user asks your name, identity, who you are, or what model you are, answer explicitly that your name is RubiNet. "
    "Prefer replies like 'My name is RubiNet.' or 'I am RubiNet.' "
    "Do not present yourself as ChatGPT, Cascade, OpenAI, or any other assistant. "
    "You are a helpful, sharp, consistent assistant trained for high-quality dialogue and reasoning. "
    "Respond clearly and directly."
)
DEFAULT_MODEL_ID = "mistralai/Ministral-3-3B-Base-2512"
DEFAULT_ADAPTER_DIR = r"D:\Downloads"
DEFAULT_OFFLOAD_DIR = r"C:\Users\ASUS\CascadeProjects\.hf-offload"
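# NOTE: these defaults point at the uploader's local Windows paths; on other machines, pass
# --adapter-dir / --offload-folder (and --model-id if needed) on the command line instead.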

def build_prompt(user_text: str, system_prompt: str) -> str:
    return (
        f"<|system|>\n{system_prompt}\n\n"
        f"<|user|>\n{user_text.strip()}\n\n"
        f"<|assistant|>\n"
    )
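# Example rendering (illustrative): build_prompt("What is your name?", SYSTEM_PROMPT) returns one string:
#   <|system|>
#   You are RubiNet. ... Respond clearly and directly.
#
#   <|user|>
#   What is your name?
#
#   <|assistant|>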

def resolve_dtype(dtype_name: str):
    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_name not in mapping:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return mapping[dtype_name]
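# For example, resolve_dtype("bfloat16") returns torch.bfloat16, while an unsupported name such
# as "int8" raises ValueError.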

def load_model(model_id: str, adapter_dir: str, use_4bit: bool, cpu_dtype: str, offload_folder: str):
    adapter_config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(adapter_config_path):
        raise FileNotFoundError(f"adapter_config.json not found in adapter dir: {adapter_dir}")

    print("Loading base tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Recent transformers releases accept `dtype=` in from_pretrained; on older releases the
    # same keyword is spelled `torch_dtype=`.
    model_kwargs = {"low_cpu_mem_usage": True}
    if use_4bit:
        if torch.cuda.is_available():
            model_kwargs["device_map"] = {"": 0}
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
        else:
            # bitsandbytes 4-bit quantization needs a CUDA device; fall back to a plain CPU load.
            model_kwargs["device_map"] = "cpu"
            model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
    else:
        if torch.cuda.is_available():
            model_kwargs["device_map"] = {"": 0}
            model_kwargs["dtype"] = torch.float16
        else:
            model_kwargs["device_map"] = "cpu"
            model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
    if offload_folder:
        os.makedirs(offload_folder, exist_ok=True)
        model_kwargs["offload_folder"] = offload_folder
    try:
        from peft import PeftModel
    except ImportError as exc:
        raise ImportError(
            "Failed to import peft. Your installed peft/transformers versions are incompatible. "
            "This script needs a transformers version that exposes EncoderDecoderCache, or a peft version "
            "compatible with your current transformers install."
        ) from exc
print("Loading base model weights...")
model_class = Mistral3ForConditionalGeneration or Ministral3ForCausalLM or AutoModelForCausalLM
try:
base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **model_kwargs)
except AttributeError as exc:
if not use_4bit or "set_submodule" not in str(exc):
raise
print("4-bit loading is not supported for this Mistral3 model class in the current transformers build. Retrying without 4-bit quantization...")
retry_kwargs = {"low_cpu_mem_usage": True}
if torch.cuda.is_available():
retry_kwargs["device_map"] = {"": 0}
retry_kwargs["dtype"] = torch.float16
else:
retry_kwargs["device_map"] = "cpu"
retry_kwargs["dtype"] = resolve_dtype(cpu_dtype)
if offload_folder:
os.makedirs(offload_folder, exist_ok=True)
retry_kwargs["offload_folder"] = offload_folder
base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **retry_kwargs)
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, adapter_dir)
model.eval()
print("Model ready.")
return tokenizer, model
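# Standalone usage sketch (the adapter path below is illustrative, not shipped with this repo):
#   tokenizer, model = load_model(
#       DEFAULT_MODEL_ID, r"D:\path\to\adapter", use_4bit=True,
#       cpu_dtype="bfloat16", offload_folder=None,
#   )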

def generate_reply(tokenizer, model, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    print("Generating reply...")
    # Greedy decoding when temperature is 0/None; only pass sampling knobs when actually sampling,
    # so greedy runs do not trigger generation-config warnings.
    do_sample = temperature is not None and float(temperature) > 0
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            **sampling_kwargs,
        )
    generated = tokenizer.decode(output[0], skip_special_tokens=False)
    # Keep only the text after the assistant marker, then trim at the first stop marker.
    if "<|assistant|>" in generated:
        reply = generated.split("<|assistant|>")[-1].strip()
    else:
        reply = generated[len(prompt):].strip()
    for stop_marker in ("<|user|>", "<|system|>", "<|assistant|>", "</s>"):
        if stop_marker in reply:
            reply = reply.split(stop_marker)[0].strip()
    return reply or "[empty reply]"
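# Example call (values are illustrative): temperature=0.0 selects greedy decoding, e.g.
#   generate_reply(tokenizer, model, build_prompt("Hello", SYSTEM_PROMPT), 64, 0.0, 0.9)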

def main():
    parser = argparse.ArgumentParser(description="Chat with Ministral 3B HMC-lite LoRA adapter")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--adapter-dir", default=DEFAULT_ADAPTER_DIR)
    parser.add_argument("--system-prompt", default=SYSTEM_PROMPT)
    parser.add_argument("--max-new-tokens", type=int, default=192)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--use-4bit", action="store_true")
    parser.add_argument("--cpu-dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
    parser.add_argument("--offload-folder", default=DEFAULT_OFFLOAD_DIR)
    parser.add_argument("--message", default="")
    args = parser.parse_args()

    print("Starting model load...")
    tokenizer, model = load_model(args.model_id, args.adapter_dir, args.use_4bit, args.cpu_dtype, args.offload_folder)

    # One-shot mode: answer a single --message and exit.
    if args.message:
        prompt = build_prompt(args.message, args.system_prompt)
        print(generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p))
        return

    # Interactive mode: loop until the user types 'exit' or 'quit'.
    print("Interactive HMC-lite chat. Type 'exit' to quit.")
    while True:
        user_text = input("You: ").strip()
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        prompt = build_prompt(user_text, args.system_prompt)
        reply = generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p)
        print(f"Assistant: {reply}\n")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        traceback.print_exc()
        raise
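# Example invocations (flags come from the argparse definitions above; paths are illustrative):
#   python ministral_3b_hmc_chat.py --use-4bit --message "What is your name?"
#   python ministral_3b_hmc_chat.py --adapter-dir ./rubinet-lora --cpu-dtype float32 --max-new-tokens 256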