Prithvik-1 committed on
Commit
49365b9
·
verified ·
1 Parent(s): d048517

Upload scripts/inference/inference_codellama.py with huggingface_hub

Browse files
scripts/inference/inference_codellama.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference script for CodeLlama 7B
4
+ Supports both Ollama and local fine-tuned models
5
+ Updated for CodeLlama fine-tuned models
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import argparse
11
+ import requests
12
+ import json
13
+ import time
14
+ from typing import Optional, List
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
16
+ from peft import PeftModel
17
+ import torch
18
+ from threading import Thread
19
+ from pathlib import Path
20
+
21
# Get script directory for relative paths.
# __file__ is scripts/inference/inference_codellama.py, so three .parent hops
# resolve to the repository root — TODO confirm if the script is ever moved.
SCRIPT_DIR = Path(__file__).parent.parent.parent

# Configuration
DEFAULT_OLLAMA_URL = "http://localhost:11434"  # default local Ollama daemon endpoint
OLLAMA_MODEL_NAME = "codellama:7b"  # Ollama model tag used when none is given
# Local base checkpoint used as the LoRA backbone / tokenizer fallback.
DEFAULT_BASE_MODEL = str(SCRIPT_DIR / "models" / "base-models" / "CodeLlama-7B-Instruct")
# Default fine-tuned (LoRA) output directory produced by training.
DEFAULT_FINETUNED_MODEL = str(SCRIPT_DIR / "training-outputs" / "codellama-fifo-v1")
29
+
30
def extract_code_from_response(text: str) -> str:
    """Pull Verilog source out of a markdown-fenced model response.

    A ```verilog fence is preferred; otherwise the first generic ``` fence
    is used (its language-identifier line is skipped). When no *complete*
    fence exists, the text is returned stripped, on the assumption it is
    already bare code. Empty/falsy input is returned unchanged.
    """
    if not text:
        return text

    # Preferred path: an explicitly verilog-tagged fence with a closer.
    tag = '```verilog'
    tag_pos = text.find(tag)
    if tag_pos != -1:
        body_start = tag_pos + len(tag)
        close = text.find('```', body_start)
        if close != -1:
            return text[body_start:close].strip()

    # Fallback: first generic fence. Skip past the language identifier by
    # jumping to the next newline (or just past the backticks if none).
    fence = text.find('```')
    if fence != -1:
        newline = text.find('\n', fence)
        body_start = fence + 3 if newline == -1 else newline + 1
        close = text.find('```', body_start)
        if close != -1:
            return text[body_start:close].strip()

    # No complete fence found: treat the whole response as raw code.
    return text.strip()
66
+
67
def get_device_info() -> dict:
    """Detect and return the available compute device.

    Returns a dict with keys:
        device (str): "cuda", "mps", or "cpu".
        device_type (str): same value as ``device`` (kept for callers that
            branch on it separately).
        use_quantization (bool): True only on CUDA — BitsAndBytes 4-bit
            quantization is CUDA-only.
        dtype (torch.dtype): float16 on GPU, float32 on CPU.
        device_count / device_name: present only when CUDA is available.

    Side effect: prints a human-readable summary of the detected hardware.
    """
    # Pessimistic CPU defaults; upgraded below if a GPU is found.
    device_info = {
        "device": "cpu",
        "device_type": "cpu",
        "use_quantization": False,
        "dtype": torch.float32
    }

    if torch.cuda.is_available():
        device_info["device"] = "cuda"
        device_info["device_type"] = "cuda"
        device_info["use_quantization"] = True
        device_info["dtype"] = torch.float16
        device_info["device_count"] = torch.cuda.device_count()
        device_info["device_name"] = torch.cuda.get_device_name(0)
        if device_info["device_count"] > 1:
            print(f"✓ {device_info['device_count']} CUDA GPUs detected:")
            for i in range(device_info["device_count"]):
                print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
            print(f" Model will be automatically distributed across all GPUs")
        else:
            print(f"✓ CUDA GPU detected: {device_info['device_name']}")
    # hasattr guard keeps this working on torch builds without MPS support.
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_info["device"] = "mps"
        device_info["device_type"] = "mps"
        device_info["use_quantization"] = False  # BitsAndBytes doesn't support MPS
        device_info["dtype"] = torch.float16
        print("✓ Apple Silicon GPU (MPS) detected")
    else:
        print("⚠ No GPU detected, using CPU (inference will be slow)")
        device_info["dtype"] = torch.float32  # already the default; kept explicit

    return device_info
101
+
102
def load_local_model(model_path: str, base_model_path: Optional[str] = None, use_quantization: Optional[bool] = None, merge_weights: bool = False):
    """Load a fine-tuned CodeLlama model (full checkpoint or LoRA adapter).

    Args:
        model_path: Directory holding either a full model or a LoRA adapter
            (detected by the presence of ``adapter_config.json``).
        base_model_path: Optional explicit base model for LoRA adapters.
        use_quantization: None = auto-detect from the device (4-bit on CUDA),
            True/False = force on/off.
        merge_weights: If True, fold LoRA weights into the base model
            (slower load, faster inference).

    Returns:
        (model, tokenizer) with the model already in eval() mode.

    Side effects: prints progress/placement diagnostics to stdout.
    """
    device_info = get_device_info()
    print(f"\nLoading model from: {model_path}")

    # Determine quantization based on device if not explicitly set
    if use_quantization is None:
        use_quantization = device_info["use_quantization"]

    # Load tokenizer (try from model path first, fallback to base model)
    tokenizer_path = model_path
    if not os.path.exists(os.path.join(model_path, "tokenizer_config.json")):
        if base_model_path and os.path.exists(base_model_path):
            tokenizer_path = base_model_path
        else:
            tokenizer_path = DEFAULT_BASE_MODEL

    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token is None:
        # CodeLlama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Check if it's a LoRA adapter
    adapter_config_path = os.path.join(model_path, "adapter_config.json")
    is_lora = os.path.exists(adapter_config_path)

    # Prepare model loading kwargs.
    # Defined as a closure so it can read device_info captured above.
    def get_model_kwargs(quantize=False):
        kwargs = {"trust_remote_code": True}
        if quantize and device_info["device_type"] == "cuda":
            # 4-bit NF4 quantization (CUDA only); device_map="auto" lets
            # accelerate shard the model across available GPUs.
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            kwargs["device_map"] = "auto"
        else:
            kwargs["torch_dtype"] = device_info["dtype"]
            if device_info["device_type"] == "mps":
                kwargs["device_map"] = "auto"
            elif device_info["device_type"] == "cuda":
                kwargs["device_map"] = "auto"
            else:
                kwargs["device_map"] = "cpu"
        return kwargs

    if is_lora:
        # Determine base model path: explicit arg > local default checkpoint
        # > training_config.json record > HuggingFace hub fallback.
        if base_model_path and os.path.exists(base_model_path):
            base_model_name = base_model_path
            print(f"Loading base model from specified path: {base_model_name}")
        elif os.path.exists(DEFAULT_BASE_MODEL):
            base_model_name = DEFAULT_BASE_MODEL
            print(f"Loading base model from default path: {base_model_name}")
        else:
            # Try to read from training config
            config_path = os.path.join(model_path, "training_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
                base_model_name = config.get("base_model", "codellama/CodeLlama-7b-Instruct-hf")
                print(f"Loading base model from training config: {base_model_name}")
            else:
                base_model_name = "codellama/CodeLlama-7b-Instruct-hf"
                print(f"Loading base model from HuggingFace: {base_model_name}")

        # Load base model. local_files_only is forced off for hub IDs
        # ("codellama/..."); on for paths that exist locally.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            local_files_only=os.path.exists(base_model_name) and not base_model_name.startswith("codellama/"),
            **get_model_kwargs(use_quantization)
        )

        # Load LoRA adapter
        print("Loading LoRA adapter...")
        model = PeftModel.from_pretrained(base_model, model_path)

        if merge_weights:
            print("Merging LoRA weights into base model...")
            model = model.merge_and_unload()
        else:
            print("Using LoRA adapter (weights not merged - faster loading)")
    else:
        # Load full model
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **get_model_kwargs(use_quantization)
        )

    model.eval()

    # Report device placement for multi-GPU setups
    if device_info["device_type"] == "cuda" and device_info.get("device_count", 1) > 1:
        print(f"\nMulti-GPU Model Distribution:")
        for name, module in model.named_modules():
            if hasattr(module, 'weight') and module.weight is not None:
                device = next(module.parameters()).device
                if device.type == 'cuda':
                    print(f" {name[:50]:<50} -> GPU {device.index}")
                break  # Just show first layer's device
        print(f" (Model automatically split across {device_info['device_count']} GPUs)")
    else:
        print(f"✅ Model loaded successfully on {device_info['device']}!")

    return model, tokenizer
209
+
210
def generate_with_local_model(model, tokenizer, prompt: str, max_new_tokens: int = 800, temperature: float = 0.3, stream: bool = False, use_chat_template: bool = True):
    """Generate text using a local CodeLlama model.

    Args:
        model / tokenizer: Objects returned by load_local_model().
        prompt: Raw user prompt, or text already in [INST] chat format.
        max_new_tokens: Generation budget for new tokens.
        temperature: 0 disables sampling (greedy); >0 enables sampling.
        stream: If True, tokens are printed as they arrive and timing stats
            are returned alongside the text.
        use_chat_template: Wrap the prompt with the tokenizer's chat template
            unless it already contains [INST]/</s> markers.

    Returns:
        stream=True:  (response, token_count, elapsed_time, tokens_per_second)
            NOTE(review): token_count counts streamer text *chunks*, which
            only approximates the real token count — confirm if the speed
            figure matters.
        stream=False: response string only.
    """
    # Check if prompt is already in chat template format (contains [INST] or </s>)
    if use_chat_template and ("[INST]" not in prompt and "</s>" not in prompt):
        # Prompt is not in chat format - need to convert it.
        # Heuristic: the first blank line separates system prompt from task.
        # Assume format: "System prompt...\n\nUser task"
        parts = prompt.split("\n\n", 1)
        if len(parts) == 2:
            system_message = parts[0].strip()
            user_message = parts[1].strip()
        else:
            # Default system prompt
            system_message = "You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements."
            user_message = prompt

        # Apply CodeLlama chat template
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message}
        ]
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True  # Adds [/INST] at the end
        )
    else:
        # Prompt is already in chat template format or chat template disabled
        formatted_prompt = prompt

    # Truncate long prompts; 1536 input tokens leaves room for generation.
    inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=1536).to(model.device)

    if stream:
        # Streaming generation: model.generate runs in a worker thread while
        # this thread consumes and prints decoded text from the streamer.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,  # Use max_new_tokens for CodeLlama
            temperature=temperature,
            do_sample=temperature > 0,  # Use sampling if temperature > 0
            top_p=0.9 if temperature > 0 else None,
            repetition_penalty=1.2,  # Higher penalty to prevent repetition (was 1.1)
            # NOTE(review): a falsy check treats pad_token_id == 0 as unset and
            # substitutes eos — confirm the tokenizer never uses id 0 for pad.
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )

        # Start generation in a separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the output
        generated_text = ""
        token_count = 0
        start_time = time.time()

        for text in streamer:
            generated_text += text
            token_count += 1
            print(text, end="", flush=True)

        thread.join()

        end_time = time.time()
        elapsed_time = end_time - start_time
        tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0

        # Generated text is already only the new tokens (streamer skips prompt)
        response = generated_text.strip()

        # Remove trailing EOS if present
        if response.endswith(tokenizer.eos_token):
            response = response[:-len(tokenizer.eos_token)].rstrip()

        # Extract code from markdown blocks if present
        response = extract_code_from_response(response)

        return response, token_count, elapsed_time, tokens_per_second
    else:
        # Non-streaming generation (original behavior)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,  # Use max_new_tokens for CodeLlama
                temperature=temperature,
                do_sample=temperature > 0,  # Use sampling if temperature > 0
                top_p=0.9 if temperature > 0 else None,
                repetition_penalty=1.2,  # Higher penalty to prevent repetition (was 1.1)
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (after the input prompt)
        input_length = inputs['input_ids'].shape[1]
        generated_ids = outputs[0][input_length:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=False)

        # Remove trailing EOS token if present
        if response.endswith(tokenizer.eos_token):
            response = response[:-len(tokenizer.eos_token)].rstrip()

        # Extract code from markdown blocks if present
        response = extract_code_from_response(response)

        return response
316
+
317
def generate_with_ollama(prompt: str, model_name: str = OLLAMA_MODEL_NAME, url: str = DEFAULT_OLLAMA_URL, max_tokens: int = 800, temperature: float = 0.3):
    """Send a single non-streaming generation request to the Ollama API.

    Returns the generated text with any leading instruction-style
    "### Response:" header removed. Exits the process (sys.exit(1)) when
    the daemon is unreachable or the request fails.
    """
    # CodeLlama takes the prompt as-is — no extra formatting applied here.
    payload = {
        "model": model_name,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": temperature,
            "num_predict": max_tokens,
        },
    }

    try:
        resp = requests.post(f"{url}/api/generate", json=payload, timeout=120)
        resp.raise_for_status()
        raw = resp.json().get("response", "")
        # Keep only the text after an echoed "### Response:" header, if any.
        return raw.split("### Response:\n")[-1].strip()
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to Ollama at {url}")
        print("Make sure Ollama is running. Start it with: ollama serve")
        sys.exit(1)
    except requests.exceptions.RequestException as e:
        print(f"Error calling Ollama API: {e}")
        sys.exit(1)
350
+
351
def interactive_mode(use_ollama: bool, model_path: Optional[str] = None, base_model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None, merge_weights: bool = False):
    """Run an interactive REPL-style inference session on stdin/stdout.

    In local mode the model is loaded once up front; each prompt is then
    answered with streaming generation. 'quit'/'exit'/'q' (or Ctrl-C) ends
    the session. Exits the process (sys.exit(1)) on an invalid model path.
    """
    model = None
    tokenizer = None

    if not use_ollama:
        if not model_path:
            print("Error: no model path provided for local mode")
            sys.exit(1)
        # Paths containing "/" may be HuggingFace hub IDs, so only reject
        # missing paths that cannot be hub IDs.
        if not os.path.exists(model_path) and "/" not in model_path:
            print(f"Error: Model path {model_path} does not exist")
            sys.exit(1)
        model, tokenizer = load_local_model(model_path, base_model_path, use_quantization, merge_weights)

    print("\n" + "=" * 50)
    print("CodeLlama 7B Interactive Inference")
    print("Type 'quit' or 'exit' to stop")
    print("=" * 50 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            # Ignore empty input lines.
            if not user_input:
                continue

            print("\nAssistant: ", end="", flush=True)

            if use_ollama:
                start_time = time.time()
                response = generate_with_ollama(user_input, ollama_model, ollama_url)
                end_time = time.time()
                inference_time = end_time - start_time
                print(response)
                print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
            else:
                # Use streaming for local model — tokens are printed by
                # generate_with_local_model as they arrive.
                response, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
                    model, tokenizer, user_input, max_new_tokens=800, temperature=0.3, stream=True
                )
                print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")

            print()

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as e:
            # Keep the session alive on per-prompt failures.
            print(f"\nError: {e}")
404
+
405
def single_inference(prompt: str, use_ollama: bool, model_path: Optional[str] = None, base_model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None, merge_weights: bool = False, max_new_tokens: int = 800, temperature: float = 0.3):
    """Run a single inference and print the response plus timing stats.

    Args:
        prompt: Raw user prompt (chat formatting is handled downstream).
        use_ollama: If True, call the Ollama HTTP API; otherwise load a
            local fine-tuned model.
        model_path: Fine-tuned model directory (required for local mode).
        base_model_path: Optional base-model path for LoRA adapters.
        ollama_model / ollama_url: Ollama model tag and endpoint.
        use_quantization: None = auto-detect from device, False = force off.
        merge_weights: Merge LoRA weights into the base model after loading.
        max_new_tokens / temperature: Generation limits, applied in BOTH modes.

    Exits the process (sys.exit(1)) on an invalid local model path.
    """

    if use_ollama:
        start_time = time.time()
        # Fix: forward max_new_tokens/temperature to Ollama. Previously the
        # hard-coded defaults inside generate_with_ollama were always used,
        # silently ignoring this function's parameters.
        response = generate_with_ollama(
            prompt, ollama_model, ollama_url,
            max_tokens=max_new_tokens, temperature=temperature
        )
        end_time = time.time()
        inference_time = end_time - start_time
        print(response)
        print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
    else:
        if not model_path:
            print("Error: no model path provided for local mode")
            sys.exit(1)
        # Paths containing "/" may be HuggingFace hub IDs; only reject
        # missing paths that cannot be hub IDs.
        if not os.path.exists(model_path) and "/" not in model_path:
            print(f"Error: Model path {model_path} does not exist")
            sys.exit(1)
        model, tokenizer = load_local_model(model_path, base_model_path, use_quantization, merge_weights)

        # Use streaming for local model — tokens print as they arrive.
        response, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
            model, tokenizer, prompt, max_new_tokens=max_new_tokens, temperature=temperature, stream=True
        )
        print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")
429
+
430
def main():
    """Parse CLI arguments and dispatch to single-shot or interactive inference."""
    parser = argparse.ArgumentParser(description="CodeLlama 7B Inference Script")
    parser.add_argument(
        "--mode",
        choices=["local", "ollama"],
        default="local",
        help="Inference mode: local (fine-tuned model) or ollama (Ollama API)"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_FINETUNED_MODEL,
        help=f"Path to fine-tuned model (for local mode, default: {DEFAULT_FINETUNED_MODEL})"
    )
    parser.add_argument(
        "--base-model-path",
        type=str,
        default=None,
        help=f"Path to base model (if different from default: {DEFAULT_BASE_MODEL})"
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=OLLAMA_MODEL_NAME,
        help="Ollama model name (default: codellama:7b)"
    )
    parser.add_argument(
        "--ollama-url",
        type=str,
        default=DEFAULT_OLLAMA_URL,
        help="Ollama API URL (default: http://localhost:11434)"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Single prompt to process (if not provided, runs in interactive mode)"
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable quantization for local models (requires more memory)"
    )
    parser.add_argument(
        "--merge-weights",
        action="store_true",
        help="Merge LoRA weights into base model (slower loading but faster inference)"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=800,
        help="Maximum number of new tokens to generate (default: 800)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="Temperature for generation (default: 0.3, lower = more deterministic)"
    )

    args = parser.parse_args()

    use_ollama = args.mode == "ollama"
    use_quantization = False if args.no_quantization else None  # Auto-detect based on device unless disabled

    if args.prompt:
        # Fix: delegate to single_inference instead of duplicating its body
        # inline — the previous copy had already drifted (CLI generation
        # flags were dropped on one path).
        single_inference(
            args.prompt,
            use_ollama,
            model_path=args.model_path,
            base_model_path=args.base_model_path,
            ollama_model=args.ollama_model,
            ollama_url=args.ollama_url,
            use_quantization=use_quantization,
            merge_weights=args.merge_weights,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
    else:
        interactive_mode(
            use_ollama,
            args.model_path if not use_ollama else None,
            args.base_model_path,
            args.ollama_model,
            args.ollama_url,
            use_quantization,
            args.merge_weights
        )
527
+
528
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
530
+