Prithvik-1 commited on
Commit
3ba49d5
·
verified ·
1 Parent(s): eac8397

Upload models/msp/inference/inference_mistral7b.py with huggingface_hub

Browse files
models/msp/inference/inference_mistral7b.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference script for Mistral 7B
4
+ Supports both Ollama and local fine-tuned models
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import requests
11
+ import json
12
+ import time
13
+ from typing import Optional, List
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
15
+ from peft import PeftModel
16
+ import torch
17
+ from threading import Thread
18
+
19
+ # Configuration
20
+ DEFAULT_OLLAMA_URL = "http://localhost:11434"
21
+ OLLAMA_MODEL_NAME = "mistral:7b"
22
+
23
+ def get_device_info():
24
+ """Detect and return available compute device"""
25
+ device_info = {
26
+ "device": "cpu",
27
+ "device_type": "cpu",
28
+ "use_quantization": False,
29
+ "dtype": torch.float32
30
+ }
31
+
32
+ if torch.cuda.is_available():
33
+ device_info["device"] = "cuda"
34
+ device_info["device_type"] = "cuda"
35
+ device_info["use_quantization"] = True
36
+ device_info["dtype"] = torch.float16
37
+ device_info["device_count"] = torch.cuda.device_count()
38
+ device_info["device_name"] = torch.cuda.get_device_name(0)
39
+ if device_info["device_count"] > 1:
40
+ print(f"✓ {device_info['device_count']} CUDA GPUs detected:")
41
+ for i in range(device_info["device_count"]):
42
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
43
+ print(f" Model will be automatically distributed across all GPUs")
44
+ else:
45
+ print(f"✓ CUDA GPU detected: {device_info['device_name']}")
46
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
47
+ device_info["device"] = "mps"
48
+ device_info["device_type"] = "mps"
49
+ device_info["use_quantization"] = False # BitsAndBytes doesn't support MPS
50
+ device_info["dtype"] = torch.float16
51
+ print("✓ Apple Silicon GPU (MPS) detected")
52
+ else:
53
+ print("⚠ No GPU detected, using CPU (inference will be slow)")
54
+ device_info["dtype"] = torch.float32
55
+
56
+ return device_info
57
+
58
+ def load_local_model(model_path: str, use_quantization: Optional[bool] = None):
59
+ """Load a fine-tuned model from local path"""
60
+ device_info = get_device_info()
61
+ print(f"\nLoading model from: {model_path}")
62
+
63
+ # Determine quantization based on device if not explicitly set
64
+ if use_quantization is None:
65
+ use_quantization = device_info["use_quantization"]
66
+
67
+ # Load tokenizer
68
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
69
+ if tokenizer.pad_token is None:
70
+ tokenizer.pad_token = tokenizer.eos_token
71
+
72
+ # Check if it's a LoRA adapter
73
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
74
+ is_lora = os.path.exists(adapter_config_path)
75
+
76
+ # Prepare model loading kwargs
77
+ def get_model_kwargs(quantize=False):
78
+ kwargs = {"trust_remote_code": True}
79
+ if quantize and device_info["device_type"] == "cuda":
80
+ kwargs["quantization_config"] = BitsAndBytesConfig(
81
+ load_in_4bit=True,
82
+ bnb_4bit_quant_type="nf4",
83
+ bnb_4bit_compute_dtype=torch.float16,
84
+ )
85
+ kwargs["device_map"] = "auto"
86
+ else:
87
+ kwargs["torch_dtype"] = device_info["dtype"]
88
+ if device_info["device_type"] == "mps":
89
+ kwargs["device_map"] = "auto"
90
+ elif device_info["device_type"] == "cuda":
91
+ kwargs["device_map"] = "auto"
92
+ else:
93
+ kwargs["device_map"] = "cpu"
94
+ return kwargs
95
+
96
+ if is_lora:
97
+ # Load base model - prefer local model to avoid cache issues
98
+ local_base_model = "/workspace/ftt/base_models/Mistral-7B-v0.1"
99
+
100
+ # Check if local model exists, otherwise use HuggingFace
101
+ if os.path.exists(local_base_model):
102
+ base_model_name = local_base_model
103
+ print(f"Loading base model from local: {base_model_name}")
104
+ else:
105
+ base_model_name = "mistralai/Mistral-7B-v0.1"
106
+ print(f"Loading base model from HuggingFace: {base_model_name}")
107
+
108
+ base_model = AutoModelForCausalLM.from_pretrained(
109
+ base_model_name,
110
+ local_files_only=os.path.exists(local_base_model),
111
+ **get_model_kwargs(use_quantization)
112
+ )
113
+
114
+ # Load LoRA adapter
115
+ print("Loading LoRA adapter...")
116
+ model = PeftModel.from_pretrained(base_model, model_path)
117
+ model = model.merge_and_unload() # Merge adapter weights
118
+ else:
119
+ # Load full model
120
+ model = AutoModelForCausalLM.from_pretrained(
121
+ model_path,
122
+ **get_model_kwargs(use_quantization)
123
+ )
124
+
125
+ model.eval()
126
+
127
+ # Report device placement for multi-GPU setups
128
+ if device_info["device_type"] == "cuda" and device_info.get("device_count", 1) > 1:
129
+ print(f"\nMulti-GPU Model Distribution:")
130
+ for name, module in model.named_modules():
131
+ if hasattr(module, 'weight') and module.weight is not None:
132
+ device = next(module.parameters()).device
133
+ if device.type == 'cuda':
134
+ print(f" {name[:50]:<50} -> GPU {device.index}")
135
+ break # Just show first layer's device
136
+ print(f" (Model automatically split across {device_info['device_count']} GPUs)")
137
+ else:
138
+ print(f"Model loaded successfully on {device_info['device']}!")
139
+
140
+ return model, tokenizer
141
+
142
+ def generate_with_local_model(model, tokenizer, prompt: str, max_length: int = 512, temperature: float = 0.7, stream: bool = False):
143
+ """Generate text using local model"""
144
+ # Use prompt as-is - don't reformat it
145
+ # The user should provide the prompt in the correct format for their model
146
+ formatted_prompt = prompt
147
+
148
+ inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
149
+
150
+ if stream:
151
+ # Streaming generation
152
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
153
+ generation_kwargs = dict(
154
+ **inputs,
155
+ max_new_tokens=max_length, # Use max_new_tokens instead of max_length
156
+ temperature=temperature,
157
+ do_sample=True,
158
+ top_p=0.9,
159
+ repetition_penalty=1.1, # Prevent repetition
160
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
161
+ eos_token_id=tokenizer.eos_token_id,
162
+ streamer=streamer,
163
+ )
164
+
165
+ # Start generation in a separate thread
166
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
167
+ thread.start()
168
+
169
+ # Stream the output
170
+ generated_text = ""
171
+ token_count = 0
172
+ start_time = time.time()
173
+
174
+ for text in streamer:
175
+ generated_text += text
176
+ token_count += 1
177
+ print(text, end="", flush=True)
178
+
179
+ thread.join()
180
+
181
+ end_time = time.time()
182
+ elapsed_time = end_time - start_time
183
+ tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
184
+
185
+ # Extract only the generated part (after the prompt)
186
+ if prompt in generated_text:
187
+ response = generated_text[len(prompt):].strip()
188
+ else:
189
+ response = generated_text.strip()
190
+
191
+ return response, token_count, elapsed_time, tokens_per_second
192
+ else:
193
+ # Non-streaming generation (original behavior)
194
+ with torch.no_grad():
195
+ outputs = model.generate(
196
+ **inputs,
197
+ max_new_tokens=max_length, # Use max_new_tokens instead of max_length
198
+ temperature=temperature,
199
+ do_sample=True,
200
+ top_p=0.9,
201
+ repetition_penalty=1.1, # Prevent repetition
202
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
203
+ eos_token_id=tokenizer.eos_token_id,
204
+ )
205
+
206
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
207
+ # Extract only the generated part (after the prompt)
208
+ if prompt in generated_text:
209
+ response = generated_text[len(prompt):].strip()
210
+ else:
211
+ response = generated_text.strip()
212
+ return response
213
+
214
+ def generate_with_ollama(prompt: str, model_name: str = OLLAMA_MODEL_NAME, url: str = DEFAULT_OLLAMA_URL, max_tokens: int = 512, temperature: float = 0.7):
215
+ """Generate text using Ollama API"""
216
+ formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
217
+
218
+ try:
219
+ response = requests.post(
220
+ f"{url}/api/generate",
221
+ json={
222
+ "model": model_name,
223
+ "prompt": formatted_prompt,
224
+ "stream": False,
225
+ "options": {
226
+ "temperature": temperature,
227
+ "num_predict": max_tokens,
228
+ }
229
+ },
230
+ timeout=120
231
+ )
232
+ response.raise_for_status()
233
+ result = response.json()
234
+ generated_text = result.get("response", "")
235
+
236
+ # Extract only the response part
237
+ response_text = generated_text.split("### Response:\n")[-1].strip()
238
+ return response_text
239
+ except requests.exceptions.ConnectionError:
240
+ print(f"Error: Could not connect to Ollama at {url}")
241
+ print("Make sure Ollama is running. Start it with: ollama serve")
242
+ sys.exit(1)
243
+ except requests.exceptions.RequestException as e:
244
+ print(f"Error calling Ollama API: {e}")
245
+ sys.exit(1)
246
+
247
+ def interactive_mode(use_ollama: bool, model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None):
248
+ """Run interactive inference session"""
249
+ model = None
250
+ tokenizer = None
251
+
252
+ if not use_ollama:
253
+ if not model_path:
254
+ print("Error: no model path provided for local mode")
255
+ sys.exit(1)
256
+ if not os.path.exists(model_path) and "/" not in model_path:
257
+ print(f"Error: Model path {model_path} does not exist")
258
+ sys.exit(1)
259
+ model, tokenizer = load_local_model(model_path, use_quantization)
260
+
261
+ print("\n" + "=" * 50)
262
+ print("Mistral 7B Interactive Inference")
263
+ print("Type 'quit' or 'exit' to stop")
264
+ print("=" * 50 + "\n")
265
+
266
+ while True:
267
+ try:
268
+ user_input = input("You: ").strip()
269
+
270
+ if user_input.lower() in ['quit', 'exit', 'q']:
271
+ print("Goodbye!")
272
+ break
273
+
274
+ if not user_input:
275
+ continue
276
+
277
+ print("\nAssistant: ", end="", flush=True)
278
+
279
+ if use_ollama:
280
+ start_time = time.time()
281
+ response = generate_with_ollama(user_input, ollama_model, ollama_url)
282
+ end_time = time.time()
283
+ inference_time = end_time - start_time
284
+ print(response)
285
+ print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
286
+ else:
287
+ # Use streaming for local model
288
+ response, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
289
+ model, tokenizer, user_input, stream=True
290
+ )
291
+ print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")
292
+
293
+ print()
294
+
295
+ except KeyboardInterrupt:
296
+ print("\n\nGoodbye!")
297
+ break
298
+ except Exception as e:
299
+ print(f"\nError: {e}")
300
+
301
+ def single_inference(prompt: str, use_ollama: bool, model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None):
302
+ """Run a single inference"""
303
+
304
+ if use_ollama:
305
+ start_time = time.time()
306
+ response = generate_with_ollama(prompt, ollama_model, ollama_url)
307
+ end_time = time.time()
308
+ inference_time = end_time - start_time
309
+ print(response)
310
+ print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
311
+ else:
312
+ if not model_path:
313
+ print("Error: no model path provided for local mode")
314
+ sys.exit(1)
315
+ if not os.path.exists(model_path) and "/" not in model_path:
316
+ print(f"Error: Model path {model_path} does not exist")
317
+ sys.exit(1)
318
+ model, tokenizer = load_local_model(model_path, use_quantization)
319
+
320
+ # Use streaming for local model
321
+ response, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
322
+ model, tokenizer, prompt, stream=True
323
+ )
324
+ print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")
325
+
326
+ def main():
327
+ parser = argparse.ArgumentParser(description="Mistral 7B Inference Script")
328
+ parser.add_argument(
329
+ "--mode",
330
+ choices=["local", "ollama"],
331
+ default="ollama",
332
+ help="Inference mode: local (fine-tuned model) or ollama (Ollama API)"
333
+ )
334
+ parser.add_argument(
335
+ "--model-path",
336
+ type=str,
337
+ default="./mistral7b-finetuned-ahb2apb",
338
+ help="Path to fine-tuned model (for local mode)"
339
+ )
340
+ parser.add_argument(
341
+ "--ollama-model",
342
+ type=str,
343
+ default=OLLAMA_MODEL_NAME,
344
+ help="Ollama model name (default: mistral:7b)"
345
+ )
346
+ parser.add_argument(
347
+ "--ollama-url",
348
+ type=str,
349
+ default=DEFAULT_OLLAMA_URL,
350
+ help="Ollama API URL (default: http://localhost:11434)"
351
+ )
352
+ parser.add_argument(
353
+ "--prompt",
354
+ type=str,
355
+ help="Single prompt to process (if not provided, runs in interactive mode)"
356
+ )
357
+ parser.add_argument(
358
+ "--no-quantization",
359
+ action="store_true",
360
+ help="Disable quantization for local models (requires more memory)"
361
+ )
362
+
363
+ args = parser.parse_args()
364
+
365
+ use_ollama = args.mode == "ollama"
366
+ use_quantization = False if args.no_quantization else None # Auto-detect based on device unless disabled
367
+
368
+ if args.prompt:
369
+ if use_ollama:
370
+ start_time = time.time()
371
+ response = generate_with_ollama(args.prompt, args.ollama_model, args.ollama_url)
372
+ end_time = time.time()
373
+ inference_time = end_time - start_time
374
+ print(response)
375
+ print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
376
+ else:
377
+ if not args.model_path:
378
+ print("Error: no model path provided for local mode")
379
+ sys.exit(1)
380
+ if not os.path.exists(args.model_path) and "/" not in args.model_path:
381
+ print(f"Error: Model path {args.model_path} does not exist")
382
+ sys.exit(1)
383
+ model, tokenizer = load_local_model(args.model_path, use_quantization)
384
+
385
+ # Use streaming for local model
386
+ response, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
387
+ model, tokenizer, args.prompt, stream=True
388
+ )
389
+ print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")
390
+ else:
391
+ interactive_mode(
392
+ use_ollama,
393
+ args.model_path if not use_ollama else None,
394
+
395
+ args.ollama_model,
396
+ args.ollama_url,
397
+ use_quantization
398
+ )
399
+
400
+ if __name__ == "__main__":
401
+ main()
402
+