DevHunterAI committed
Commit 1335cf8 · verified · 1 Parent(s): 7bd5276

Upload ministral_3b_hmc_chat.py with huggingface_hub
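For context, a minimal sketch of the kind of call that produces a commit like this via huggingface_hub. The target repo id is not shown on this page, so the one below is a placeholder:

from huggingface_hub import HfApi

api = HfApi()  # authenticates via HF_TOKEN or a cached `huggingface-cli login`
api.upload_file(
    path_or_fileobj="ministral_3b_hmc_chat.py",
    path_in_repo="ministral_3b_hmc_chat.py",
    repo_id="DevHunterAI/your-repo",  # placeholder: the actual repo id is not shown here
    commit_message="Upload ministral_3b_hmc_chat.py with huggingface_hub",
)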

Files changed (1)
  1. ministral_3b_hmc_chat.py +186 -0
ministral_3b_hmc_chat.py ADDED
@@ -0,0 +1,186 @@
import argparse
import os
import sys
import traceback

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

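# These classes are only present in some transformers releases; fall back to
# AutoModelForCausalLM at load time if neither import succeeds.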
try:
    from transformers import Mistral3ForConditionalGeneration
except ImportError:
    Mistral3ForConditionalGeneration = None

try:
    from transformers import Ministral3ForCausalLM
except ImportError:
    Ministral3ForCausalLM = None

if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

SYSTEM_PROMPT = "You are RubiNet. Your name is RubiNet, and your assistant identity is RubiNet. If the user asks your name, identity, who you are, or what model you are, answer explicitly that your name is RubiNet. Prefer replies like 'My name is RubiNet.' or 'I am RubiNet.' Do not present yourself as ChatGPT, Cascade, OpenAI, or any other assistant. You are a helpful, sharp, consistent assistant trained for high-quality dialogue and reasoning. Respond clearly and directly."
DEFAULT_MODEL_ID = "mistralai/Ministral-3-3B-Base-2512"
DEFAULT_ADAPTER_DIR = r"D:\Downloads"
DEFAULT_OFFLOAD_DIR = r"C:\Users\ASUS\CascadeProjects\.hf-offload"

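# Render one chat turn with explicit role markers. This assumes the LoRA
# adapter was trained on this <|system|>/<|user|>/<|assistant|> layout rather
# than the tokenizer's built-in chat template.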
def build_prompt(user_text: str, system_prompt: str) -> str:
    return (
        f"<|system|>\n{system_prompt}\n\n"
        f"<|user|>\n{user_text.strip()}\n\n"
        f"<|assistant|>\n"
    )


def resolve_dtype(dtype_name: str):
    mapping = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_name not in mapping:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return mapping[dtype_name]

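# Load the base model and attach the LoRA adapter. On a CUDA device the
# weights can optionally be quantized to 4-bit NF4 via bitsandbytes; on CPU
# the script falls back to a plain load in the requested dtype, and large
# checkpoints may spill to offload_folder on disk.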
def load_model(model_id: str, adapter_dir: str, use_4bit: bool, cpu_dtype: str, offload_folder: str):
    adapter_config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(adapter_config_path):
        raise FileNotFoundError(f"adapter_config.json not found in adapter dir: {adapter_dir}")

    print("Loading base tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {"low_cpu_mem_usage": True}
    if use_4bit:
        if torch.cuda.is_available():
            model_kwargs["device_map"] = {"": 0}
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
        else:
            model_kwargs["device_map"] = "cpu"
            model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
    else:
        if torch.cuda.is_available():
            model_kwargs["device_map"] = {"": 0}
            model_kwargs["dtype"] = torch.float16
        else:
            model_kwargs["device_map"] = "cpu"
            model_kwargs["dtype"] = resolve_dtype(cpu_dtype)
    if offload_folder:
        os.makedirs(offload_folder, exist_ok=True)
        model_kwargs["offload_folder"] = offload_folder

    try:
        from peft import PeftModel
    except ImportError as exc:
        raise ImportError(
            "Failed to import peft. Your installed peft/transformers versions are incompatible. "
            "This script needs a transformers version that exposes EncoderDecoderCache, or a peft version "
            "compatible with your current transformers install."
        ) from exc

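    # Some transformers builds raise an AttributeError mentioning set_submodule
    # when 4-bit loading this model class; catch that case and retry without
    # quantization.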
    print("Loading base model weights...")
    model_class = Mistral3ForConditionalGeneration or Ministral3ForCausalLM or AutoModelForCausalLM
    try:
        base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **model_kwargs)
    except AttributeError as exc:
        if not use_4bit or "set_submodule" not in str(exc):
            raise
        print("4-bit loading is not supported for this Mistral3 model class in the current transformers build. Retrying without 4-bit quantization...")
        retry_kwargs = {"low_cpu_mem_usage": True}
        if torch.cuda.is_available():
            retry_kwargs["device_map"] = {"": 0}
            retry_kwargs["dtype"] = torch.float16
        else:
            retry_kwargs["device_map"] = "cpu"
            retry_kwargs["dtype"] = resolve_dtype(cpu_dtype)
        if offload_folder:
            os.makedirs(offload_folder, exist_ok=True)
            retry_kwargs["offload_folder"] = offload_folder
        base_model = model_class.from_pretrained(model_id, trust_remote_code=True, **retry_kwargs)
    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    print("Model ready.")
    return tokenizer, model

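# Generate a completion for the rendered prompt, then trim it at the first
# role or EOS marker so the model cannot run on into a fabricated next turn.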
def generate_reply(tokenizer, model, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    print("Generating reply...")
    do_sample = temperature is not None and float(temperature) > 0
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    generated = tokenizer.decode(output[0], skip_special_tokens=False)
    if "<|assistant|>" in generated:
        reply = generated.split("<|assistant|>")[-1].strip()
    else:
        reply = generated[len(prompt):].strip()
    for stop_marker in ("<|user|>", "<|system|>", "<|assistant|>", "</s>"):
        if stop_marker in reply:
            reply = reply.split(stop_marker)[0].strip()
    return reply or "[empty reply]"


def main():
    parser = argparse.ArgumentParser(description="Chat with Ministral 3B HMC-lite LoRA adapter")
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--adapter-dir", default=DEFAULT_ADAPTER_DIR)
    parser.add_argument("--system-prompt", default=SYSTEM_PROMPT)
    parser.add_argument("--max-new-tokens", type=int, default=192)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--use-4bit", action="store_true")
    parser.add_argument("--cpu-dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
    parser.add_argument("--offload-folder", default=DEFAULT_OFFLOAD_DIR)
    parser.add_argument("--message", default="")
    args = parser.parse_args()

    print("Starting model load...")
    tokenizer, model = load_model(args.model_id, args.adapter_dir, args.use_4bit, args.cpu_dtype, args.offload_folder)

    if args.message:
        prompt = build_prompt(args.message, args.system_prompt)
        print(generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p))
        return

    print("Interactive HMC-lite chat. Type 'exit' to quit.")
    while True:
        user_text = input("You: ").strip()
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        prompt = build_prompt(user_text, args.system_prompt)
        reply = generate_reply(tokenizer, model, prompt, args.max_new_tokens, args.temperature, args.top_p)
        print(f"Assistant: {reply}\n")


if __name__ == "__main__":
    try:
        main()
    except Exception:
        traceback.print_exc()
        raise
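Beyond the CLI flags parsed in main(), the module can also be driven programmatically. A minimal sketch, assuming the file is on the import path and the machine-specific default paths above are replaced with valid ones:

import torch

import ministral_3b_hmc_chat as chat

tokenizer, model = chat.load_model(
    chat.DEFAULT_MODEL_ID,      # or any compatible base model id
    chat.DEFAULT_ADAPTER_DIR,   # machine-specific; point at your own adapter dir
    use_4bit=torch.cuda.is_available(),
    cpu_dtype="bfloat16",
    offload_folder=None,        # skip disk offload for this sketch
)
prompt = chat.build_prompt("Who are you?", chat.SYSTEM_PROMPT)
reply = chat.generate_reply(
    tokenizer, model, prompt,
    max_new_tokens=64, temperature=0.7, top_p=0.9,
)
print(reply)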