# Chat front-end: a base Ministral/Mistral3 checkpoint with a LoRA adapter on top.
import argparse
import os
import sys
import traceback

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# These model classes are only exported by some transformers releases. Each
# name is bound to None when the import fails, so later code can test for it.
try:
    from transformers import Mistral3ForConditionalGeneration
except ImportError:
    Mistral3ForConditionalGeneration = None

try:
    from transformers import Ministral3ForCausalLM
except ImportError:
    Ministral3ForCausalLM = None

# Switch stdout, then stderr, to UTF-8 with errors="replace" (where the stream
# supports reconfigure()) so printing model output cannot raise
# UnicodeEncodeError on a console whose default encoding is not UTF-8.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
del _stream

# Default system prompt (overridable via --system-prompt). The adjacent string
# literals below concatenate into a single line of text.
SYSTEM_PROMPT = (
    "You are RubiNet. Your name is RubiNet, and your assistant identity is RubiNet. "
    "If the user asks your name, identity, who you are, or what model you are, "
    "answer explicitly that your name is RubiNet. "
    "Prefer replies like 'My name is RubiNet.' or 'I am RubiNet.' "
    "Do not present yourself as ChatGPT, Cascade, OpenAI, or any other assistant. "
    "You are a helpful, sharp, consistent assistant trained for high-quality dialogue and reasoning. "
    "Respond clearly and directly."
)
DEFAULT_MODEL_ID = "mistralai/Ministral-3-3B-Base-2512"
DEFAULT_ADAPTER_DIR = r"D:\Downloads"
DEFAULT_OFFLOAD_DIR = r"C:\Users\ASUS\CascadeProjects\.hf-offload"

# Strings that end the assistant's reply: the start of another turn in the
# plain-text format produced by build_prompt, or the literal "</s>" EOS string.
# generate_reply also adds the tokenizer's own eos_token at runtime. Never put
# "" here: `"" in s` is always True and `s.split("")` raises ValueError.
_STOP_MARKERS = ("<|user|>", "<|system|>", "<|assistant|>", "</s>")


def build_prompt(user_text: str, system_prompt: str) -> str:
    """Format one system + user exchange in the plain-text ``<|role|>`` layout.

    The result ends with an open ``<|assistant|>`` header so the model continues
    with the assistant's reply. ``user_text`` is stripped of surrounding
    whitespace; ``system_prompt`` is inserted verbatim.
    """
    return (
        f"<|system|>\n{system_prompt}\n\n"
        f"<|user|>\n{user_text.strip()}\n\n"
        f"<|assistant|>\n"
    )


def resolve_dtype(dtype_name: str) -> torch.dtype:
    """Map "float32" / "float16" / "bfloat16" to the corresponding torch dtype.

    Raises:
        ValueError: if ``dtype_name`` is not one of the three supported names.
    """
    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_name not in mapping:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return mapping[dtype_name]


def _build_model_kwargs(use_4bit: bool, cpu_dtype: str, offload_folder: str) -> dict:
    """Build ``from_pretrained`` kwargs for the hardware that is available.

    With CUDA the whole model is mapped to GPU 0: NF4 4-bit quantized (bfloat16
    compute) when ``use_4bit`` is true, otherwise float16. Without CUDA the
    model is placed on CPU in ``cpu_dtype`` and ``use_4bit`` has no effect.
    A non-empty ``offload_folder`` is created if missing and passed through.
    """
    kwargs = {"low_cpu_mem_usage": True}
    if torch.cuda.is_available():
        kwargs["device_map"] = {"": 0}
        if use_4bit:
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
        else:
            # NOTE(review): `dtype` is the kwarg name in recent transformers;
            # older releases expect `torch_dtype` -- confirm for the installed version.
            kwargs["dtype"] = torch.float16
    else:
        kwargs["device_map"] = "cpu"
        kwargs["dtype"] = resolve_dtype(cpu_dtype)
    if offload_folder:
        os.makedirs(offload_folder, exist_ok=True)
        kwargs["offload_folder"] = offload_folder
    return kwargs


def load_model(model_id: str, adapter_dir: str, use_4bit: bool, cpu_dtype: str, offload_folder: str):
    """Load the base tokenizer and model, then wrap the model with a LoRA adapter.

    Args:
        model_id: Hub id or local path of the base model and tokenizer.
        adapter_dir: Directory containing the PEFT adapter (adapter_config.json).
        use_4bit: Request 4-bit quantization (only effective when CUDA is available).
        cpu_dtype: dtype name used when loading on CPU (see resolve_dtype).
        offload_folder: Folder passed to from_pretrained for weight offloading;
            empty string disables it.

    Returns:
        ``(tokenizer, model)`` where ``model`` is a ``PeftModel`` in eval mode.

    Raises:
        FileNotFoundError: ``adapter_dir`` has no adapter_config.json.
        ImportError: ``peft`` cannot be imported with the installed transformers.
    """
    adapter_config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(adapter_config_path):
        raise FileNotFoundError(f"adapter_config.json not found in adapter dir: {adapter_dir}")

    # Import peft before downloading/loading anything so a broken install fails fast.
    try:
        from peft import PeftModel
    except ImportError as exc:
        raise ImportError(
            "Failed to import peft. Your installed peft/transformers versions are incompatible. "
            "This script needs a transformers version that exposes EncoderDecoderCache, or a peft version "
            "compatible with your current transformers install."
        ) from exc

    print("Loading base tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS so padding/generation have a valid id.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model weights...")
    # Use the first class this transformers build provides; the Mistral3 names
    # are None when their import failed at module load.
    model_class = Mistral3ForConditionalGeneration or Ministral3ForCausalLM or AutoModelForCausalLM
    try:
        base_model = model_class.from_pretrained(
            model_id, trust_remote_code=True, **_build_model_kwargs(use_4bit, cpu_dtype, offload_folder)
        )
    except AttributeError as exc:
        # Under 4-bit, an AttributeError mentioning `set_submodule` is treated as
        # "quantization unsupported for this class": retry the same load without
        # quantization. Any other AttributeError propagates unchanged.
        if not use_4bit or "set_submodule" not in str(exc):
            raise
        print("4-bit loading is not supported for this Mistral3 model class in the current transformers build. Retrying without 4-bit quantization...")
        base_model = model_class.from_pretrained(
            model_id, trust_remote_code=True, **_build_model_kwargs(False, cpu_dtype, offload_folder)
        )

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    print("Model ready.")
    return tokenizer, model


def generate_reply(tokenizer, model, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    """Generate the assistant's reply to ``prompt`` and return it as text.

    Only the newly generated tokens are decoded (never the echoed prompt). The
    text is cut at the first stop marker, i.e. the start of another
    ``<|role|>`` turn or an EOS string. Sampling is used when ``temperature``
    is a positive number; otherwise decoding is greedy and ``temperature`` /
    ``top_p`` are not forwarded to ``generate``.

    Returns:
        The stripped reply, or ``"[empty reply]"`` if nothing remains.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]
    print("Generating reply...")

    do_sample = temperature is not None and float(temperature) > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    }
    if do_sample:
        # Sampling parameters only mean something when sampling; passing them
        # with do_sample=False makes generate() emit config warnings.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    with torch.inference_mode():
        output = model.generate(**inputs, **generation_kwargs)

    # Slice off the prompt tokens instead of slicing decoded text by len(prompt):
    # the decode may add a BOS string or normalise whitespace, shifting offsets.
    reply = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=False)

    stop_markers = list(_STOP_MARKERS)
    eos_text = getattr(tokenizer, "eos_token", None)
    if eos_text:
        stop_markers.append(eos_text)
    for stop_marker in stop_markers:
        if stop_marker and stop_marker in reply:
            reply = reply.split(stop_marker, 1)[0]
    reply = reply.strip()
    return reply or "[empty reply]"


def main():
    """Parse CLI options, load the model, then answer --message once or chat interactively."""
    parser = argparse.ArgumentParser(description="Chat with Ministral 3B HMC-lite LoRA adapter")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--adapter-dir", default=DEFAULT_ADAPTER_DIR)
    parser.add_argument("--system-prompt", default=SYSTEM_PROMPT)
    parser.add_argument("--max-new-tokens", type=int, default=192)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--use-4bit", action="store_true")
    parser.add_argument("--cpu-dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
    parser.add_argument("--offload-folder", default=DEFAULT_OFFLOAD_DIR)
    parser.add_argument("--message", default="")
    args = parser.parse_args()

    print("Starting model load...")
    tokenizer, model = load_model(args.model_id, args.adapter_dir, args.use_4bit, args.cpu_dtype, args.offload_folder)

    if args.message:
        prompt = build_prompt(args.message, args.system_prompt)
        print(generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p))
        return

    print("Interactive HMC-lite chat. Type 'exit' to quit.")
    while True:
        try:
            user_text = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input (closed stdin) or Ctrl-C ends the session cleanly
            # instead of surfacing as a traceback.
            print()
            break
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        prompt = build_prompt(user_text, args.system_prompt)
        reply = generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p)
        print(f"Assistant: {reply}\n")


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Print the traceback once and exit non-zero. A bare `raise` here would
        # make the interpreter print the same traceback a second time.
        traceback.print_exc()
        raise SystemExit(1)