|
|
|
|
|
""" |
|
|
Inference script for CodeLlama 7B |
|
|
Supports both Ollama and local fine-tuned models |
|
|
Updated for CodeLlama fine-tuned models |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import requests |
|
|
import json |
|
|
import time |
|
|
from typing import Optional, List |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
from threading import Thread |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Directory three levels above this file (presumably the project root); the
# default local checkpoint paths below are resolved against it.
SCRIPT_DIR = Path(__file__).parent.parent.parent

# Ollama server endpoint and model tag used by the "ollama" mode.
DEFAULT_OLLAMA_URL = "http://localhost:11434"
OLLAMA_MODEL_NAME = "codellama:7b"

# Local checkpoints used by the "local" mode (stored as plain strings).
DEFAULT_BASE_MODEL = str(SCRIPT_DIR.joinpath("models", "base-models", "CodeLlama-7B-Instruct"))
DEFAULT_FINETUNED_MODEL = str(SCRIPT_DIR.joinpath("training-outputs", "codellama-fifo-v1"))
|
|
|
|
|
def extract_code_from_response(text: str) -> str: |
|
|
""" |
|
|
Extract Verilog code from markdown code blocks. |
|
|
Handles both ```verilog and generic ``` markers. |
|
|
""" |
|
|
if not text: |
|
|
return text |
|
|
|
|
|
|
|
|
if '```verilog' in text: |
|
|
start = text.find('```verilog') + len('```verilog') |
|
|
end = text.find('```', start) |
|
|
if end != -1: |
|
|
extracted = text[start:end].strip() |
|
|
return extracted |
|
|
|
|
|
|
|
|
if '```' in text: |
|
|
|
|
|
start = text.find('```') |
|
|
if start != -1: |
|
|
|
|
|
start_marker = text.find('\n', start) |
|
|
if start_marker == -1: |
|
|
start_marker = start + 3 |
|
|
else: |
|
|
start_marker += 1 |
|
|
|
|
|
|
|
|
end = text.find('```', start_marker) |
|
|
if end != -1: |
|
|
extracted = text[start_marker:end].strip() |
|
|
return extracted |
|
|
|
|
|
|
|
|
return text.strip() |
|
|
|
|
|
def get_device_info():
    """Probe for the best compute backend and describe it.

    CUDA is preferred, then Apple-silicon MPS, then CPU. A status line is
    printed, plus one line per GPU when several CUDA devices are present.

    Returns:
        dict with ``device`` and ``device_type`` (both "cuda", "mps" or "cpu").
        ``use_quantization`` is True only on CUDA. ``dtype`` is
        ``torch.float16`` on CUDA/MPS and ``torch.float32`` on CPU. The CUDA
        variant additionally carries ``device_count`` and ``device_name``
        (the name of GPU 0).
    """
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
        info = {
            "device": "cuda",
            "device_type": "cuda",
            "use_quantization": True,
            "dtype": torch.float16,
            "device_count": n_gpus,
            "device_name": torch.cuda.get_device_name(0),
        }
        if n_gpus > 1:
            print(f"✓ {n_gpus} CUDA GPUs detected:")
            for gpu_idx in range(n_gpus):
                print(f" GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
            print(" Model will be automatically distributed across all GPUs")
        else:
            print(f"✓ CUDA GPU detected: {info['device_name']}")
        return info

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        print("✓ Apple Silicon GPU (MPS) detected")
        return {
            "device": "mps",
            "device_type": "mps",
            "use_quantization": False,
            "dtype": torch.float16,
        }

    print("⚠ No GPU detected, using CPU (inference will be slow)")
    return {
        "device": "cpu",
        "device_type": "cpu",
        "use_quantization": False,
        "dtype": torch.float32,
    }
|
|
|
|
|
def _resolve_base_model(model_path: str, base_model_path: Optional[str]):
    """Pick the base checkpoint for ``model_path`` and return ``(name_or_path, origin_label)``.

    The first match wins:
      1. ``base_model_path``, if given and present on disk;
      2. ``DEFAULT_BASE_MODEL``, if present on disk;
      3. ``base_model`` from ``<model_path>/training_config.json``;
      4. the public ``codellama/CodeLlama-7b-Instruct-hf`` hub model.
    """
    hub_default = "codellama/CodeLlama-7b-Instruct-hf"
    if base_model_path and os.path.exists(base_model_path):
        return base_model_path, "specified path"
    if os.path.exists(DEFAULT_BASE_MODEL):
        return DEFAULT_BASE_MODEL, "default path"
    config_path = os.path.join(model_path, "training_config.json")
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        return config.get("base_model", hub_default), "training config"
    return hub_default, "HuggingFace"


def _model_load_kwargs(device_info: dict, quantize: bool) -> dict:
    """Build ``from_pretrained`` kwargs: 4-bit NF4 on CUDA when ``quantize``, otherwise the device dtype."""
    # NOTE(review): trust_remote_code=True lets a hub repo run its own Python
    # code at load time. CodeLlama is natively supported, so this is likely
    # unnecessary; consider removing it.
    kwargs = {"trust_remote_code": True}
    if quantize and device_info["device_type"] == "cuda":
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        kwargs["device_map"] = "auto"
    else:
        kwargs["torch_dtype"] = device_info["dtype"]
        # GPU backends let accelerate place the layers; CPU is pinned explicitly.
        kwargs["device_map"] = "auto" if device_info["device_type"] in ("cuda", "mps") else "cpu"
    return kwargs


def _report_model_placement(model, device_info: dict) -> None:
    """Print where the loaded model's parameters ended up."""
    gpu_count = device_info.get("device_count", 1)
    if device_info["device_type"] != "cuda" or gpu_count <= 1:
        print(f"✅ Model loaded successfully on {device_info['device']}!")
        return
    print("\nMulti-GPU Model Distribution:")
    tensors_per_device = {}
    for param in model.parameters():
        if param.device.type == "cuda":
            key = str(param.device)
            tensors_per_device[key] = tensors_per_device.get(key, 0) + 1
    for device_name in sorted(tensors_per_device):
        print(f" {device_name}: {tensors_per_device[device_name]} parameter tensors")
    print(f" (Model automatically split across {device_info['device_count']} GPUs)")


def load_local_model(model_path: str, base_model_path: Optional[str] = None, use_quantization: Optional[bool] = None, merge_weights: bool = False):
    """Load a fine-tuned CodeLlama model from local path.

    Args:
        model_path: Directory holding either a PEFT/LoRA adapter (detected by
            ``adapter_config.json``) or a full model checkpoint.
        base_model_path: Optional local base-model directory. It supplies the
            base weights for a LoRA adapter, and serves as the tokenizer
            fallback when ``model_path`` ships no ``tokenizer_config.json``.
        use_quantization: Force 4-bit loading on or off. ``None`` means auto:
            the value comes from ``get_device_info()`` and is on only for CUDA.
            Quantization is applied only on CUDA either way.
        merge_weights: Merge the LoRA weights into the base model after loading.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode. When the tokenizer
        defines no pad token, EOS is used as the pad token.
    """
    device_info = get_device_info()
    print(f"\nLoading model from: {model_path}")

    if use_quantization is None:
        use_quantization = device_info["use_quantization"]

    has_tokenizer = os.path.exists(os.path.join(model_path, "tokenizer_config.json"))
    is_lora = os.path.exists(os.path.join(model_path, "adapter_config.json"))

    # Resolve the base model once, so that the tokenizer fallback and the LoRA
    # base weights come from the same source that actually exists.
    base_model_name = base_origin = None
    if is_lora or not has_tokenizer:
        base_model_name, base_origin = _resolve_base_model(model_path, base_model_path)

    tokenizer_path = model_path if has_tokenizer else base_model_name
    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model_kwargs = _model_load_kwargs(device_info, use_quantization)

    if is_lora:
        print(f"Loading base model from {base_origin}: {base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            # Local directories should not trigger a hub lookup.
            local_files_only=os.path.exists(base_model_name) and not base_model_name.startswith("codellama/"),
            **model_kwargs,
        )
        print("Loading LoRA adapter...")
        model = PeftModel.from_pretrained(base_model, model_path)
        if merge_weights:
            print("Merging LoRA weights into base model...")
            model = model.merge_and_unload()
        else:
            print("Using LoRA adapter (weights not merged - faster loading)")
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

    model.eval()
    _report_model_placement(model, device_info)
    return model, tokenizer
|
|
|
|
|
def _build_model_prompt(tokenizer, prompt: str, use_chat_template: bool) -> str:
    """Wrap ``prompt`` in the tokenizer's chat template unless disabled or already formatted."""
    if not use_chat_template or "[INST]" in prompt or "</s>" in prompt:
        return prompt
    # The first blank-line-separated paragraph, if any, is used as the system message.
    head, sep, tail = prompt.partition("\n\n")
    if sep:
        system_message, user_message = head.strip(), tail.strip()
    else:
        system_message = "You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements."
        user_message = prompt
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def _decoding_kwargs(tokenizer, max_new_tokens: int, temperature: float) -> dict:
    """Return the ``generate`` settings shared by the streaming and blocking paths."""
    sample = temperature > 0
    # Compare against None: a pad id of 0 is a valid token id.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    return dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=sample,
        top_p=0.9 if sample else None,
        repetition_penalty=1.2,
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
    )


def _finalize_response(tokenizer, text: str) -> str:
    """Trim whitespace and a trailing EOS marker, then pull the fenced code out."""
    response = text.strip()
    eos = tokenizer.eos_token
    if eos and response.endswith(eos):
        response = response[:-len(eos)].rstrip()
    return extract_code_from_response(response)


def generate_with_local_model(model, tokenizer, prompt: str, max_new_tokens: int = 800, temperature: float = 0.3, stream: bool = False, use_chat_template: bool = True):
    """Generate text using local CodeLlama model.

    Args:
        model: A loaded causal LM (optionally PEFT-wrapped).
        tokenizer: Its tokenizer.
        prompt: Raw prompt. Unless it already contains ``[INST]`` or ``</s>``,
            it is wrapped in the chat template when ``use_chat_template`` is set.
        max_new_tokens: Generation budget.
        temperature: ``0`` selects greedy decoding. Above ``0``, sampling with
            ``top_p=0.9`` is used.
        stream: Echo text to stdout while generating, and also return timing stats.
        use_chat_template: Disable to send ``prompt`` verbatim.

    Returns:
        With ``stream=False``: the extracted code string.
        With ``stream=True``: ``(response, token_count, elapsed_seconds,
        tokens_per_second)``, where ``token_count`` is the number of newly
        generated token ids.

    Raises:
        Whatever ``model.generate`` raises. In streaming mode the exception is
        re-raised from the background thread.
    """
    formatted_prompt = _build_model_prompt(tokenizer, prompt, use_chat_template)

    # NOTE(review): with the tokenizer's default right-side truncation, prompts
    # longer than 1536 tokens lose their tail, presumably including the chat
    # template's generation cue. Confirm whether left truncation is wanted.
    inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=1536).to(model.device)
    input_length = inputs["input_ids"].shape[1]
    decoding = _decoding_kwargs(tokenizer, max_new_tokens, temperature)

    if not stream:
        with torch.no_grad():
            outputs = model.generate(**inputs, **decoding)
        generated_ids = outputs[0][input_length:]
        return _finalize_response(tokenizer, tokenizer.decode(generated_ids, skip_special_tokens=False))

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    outcome = {}

    def _run_generation():
        try:
            outcome["ids"] = model.generate(**inputs, **decoding, streamer=streamer)
        except Exception as exc:
            outcome["error"] = exc
            # Unblock the consumer loop below; otherwise it waits forever.
            streamer.end()

    chunks = []
    start_time = time.time()
    worker = Thread(target=_run_generation)
    worker.start()
    for piece in streamer:
        chunks.append(piece)
        print(piece, end="", flush=True)
    worker.join()
    elapsed_time = time.time() - start_time

    if "error" in outcome:
        raise outcome["error"]

    # The streamer yields decoded text fragments, not tokens, so count the ids.
    token_count = int(outcome["ids"].shape[-1]) - input_length
    tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0

    response = _finalize_response(tokenizer, "".join(chunks))
    return response, token_count, elapsed_time, tokens_per_second
|
|
|
|
|
def generate_with_ollama(prompt: str, model_name: str = OLLAMA_MODEL_NAME, url: str = DEFAULT_OLLAMA_URL, max_tokens: int = 800, temperature: float = 0.3):
    """Send ``prompt`` to an Ollama server's ``/api/generate`` endpoint, non-streaming.

    Args:
        prompt: Sent verbatim; no template is applied.
        model_name: Ollama model tag.
        url: Base URL of the Ollama server.
        max_tokens: Forwarded as the ``num_predict`` option.
        temperature: Forwarded as the ``temperature`` option.

    Returns:
        The ``response`` field of the reply, stripped. If the reply contains
        ``"### Response:\\n"``, only the text after its last occurrence is kept.

    Exits the process with status 1 on connection or HTTP/request errors.
    """
    endpoint = f"{url}/api/generate"
    payload = {
        "model": model_name,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature, "num_predict": max_tokens},
    }

    try:
        reply = requests.post(endpoint, json=payload, timeout=120)
        reply.raise_for_status()
        raw_text = reply.json().get("response", "")
        # rpartition yields the whole string as the tail when the marker is absent.
        _, _, tail = raw_text.rpartition("### Response:\n")
        return tail.strip()
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to Ollama at {url}")
        print("Make sure Ollama is running. Start it with: ollama serve")
        sys.exit(1)
    except requests.exceptions.RequestException as e:
        print(f"Error calling Ollama API: {e}")
        sys.exit(1)
|
|
|
|
|
def interactive_mode(use_ollama: bool, model_path: Optional[str] = None, base_model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None, merge_weights: bool = False, max_new_tokens: int = 800, temperature: float = 0.3):
    """Run interactive inference session.

    In local mode the model is loaded once and then reused for every prompt.
    The session ends on 'quit'/'exit'/'q', Ctrl-C, or end of input. Other
    per-prompt errors are printed, and the session continues.

    Args:
        use_ollama: Use the Ollama HTTP backend instead of a local model.
        model_path, base_model_path, use_quantization, merge_weights:
            Forwarded to ``load_local_model`` (local mode only).
        ollama_model, ollama_url: Ollama backend settings.
        max_new_tokens, temperature: Generation settings for either backend.
            The defaults match the values previously hard-coded here.
    """
    model = None
    tokenizer = None

    if not use_ollama:
        if not model_path:
            print("Error: no model path provided for local mode")
            sys.exit(1)
        # Paths containing "/" that do not exist locally may be hub repo ids.
        if not os.path.exists(model_path) and "/" not in model_path:
            print(f"Error: Model path {model_path} does not exist")
            sys.exit(1)
        model, tokenizer = load_local_model(model_path, base_model_path, use_quantization, merge_weights)

    print("\n" + "=" * 50)
    print("CodeLlama 7B Interactive Inference")
    print("Type 'quit' or 'exit' to stop")
    print("=" * 50 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if not user_input:
                continue

            print("\nAssistant: ", end="", flush=True)

            if use_ollama:
                start_time = time.time()
                response = generate_with_ollama(
                    user_input, ollama_model, ollama_url,
                    max_tokens=max_new_tokens, temperature=temperature,
                )
                inference_time = time.time() - start_time
                print(response)
                print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
            else:
                # The streaming call already prints the text as it is generated.
                _, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
                    model, tokenizer, user_input,
                    max_new_tokens=max_new_tokens, temperature=temperature, stream=True,
                )
                print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")

            print()

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is closed (Ctrl-D or exhausted pipe). input()
            # would raise it again on every iteration, so the loop must stop.
            print("\n\nGoodbye!")
            break

        except Exception as e:
            print(f"\nError: {e}")
|
|
|
|
|
def single_inference(prompt: str, use_ollama: bool, model_path: Optional[str] = None, base_model_path: Optional[str] = None, ollama_model: str = OLLAMA_MODEL_NAME, ollama_url: str = DEFAULT_OLLAMA_URL, use_quantization: Optional[bool] = None, merge_weights: bool = False, max_new_tokens: int = 800, temperature: float = 0.3):
    """Run a single inference and print the result together with timing.

    ``max_new_tokens`` and ``temperature`` apply to both backends. Ollama
    receives them as ``max_tokens``/``temperature``. In local mode the output
    is streamed to stdout while it is generated.

    Exits with status 1 when local mode has no usable ``model_path``.
    """
    if use_ollama:
        start_time = time.time()
        response = generate_with_ollama(
            prompt, ollama_model, ollama_url,
            max_tokens=max_new_tokens, temperature=temperature,
        )
        inference_time = time.time() - start_time
        print(response)
        print(f"\n⏱️ Inference time: {inference_time:.2f} seconds")
        return

    if not model_path:
        print("Error: no model path provided for local mode")
        sys.exit(1)
    # Paths containing "/" that do not exist locally may be hub repo ids.
    if not os.path.exists(model_path) and "/" not in model_path:
        print(f"Error: Model path {model_path} does not exist")
        sys.exit(1)

    model, tokenizer = load_local_model(model_path, base_model_path, use_quantization, merge_weights)

    # The streaming call already prints the generated text as it arrives.
    _, token_count, elapsed_time, tokens_per_second = generate_with_local_model(
        model, tokenizer, prompt, max_new_tokens=max_new_tokens, temperature=temperature, stream=True
    )
    print(f"\n\n⏱️ Generation time: {elapsed_time:.2f}s | Tokens: {token_count} | Speed: {tokens_per_second:.2f} tokens/sec")
|
|
|
|
|
def main():
    """Parse CLI arguments, then run a single prompt or an interactive session.

    With ``--prompt`` the work is delegated to ``single_inference``, which is
    the same code path this function used to duplicate inline. Without it,
    ``interactive_mode`` is started.
    """
    parser = argparse.ArgumentParser(description="CodeLlama 7B Inference Script")
    parser.add_argument(
        "--mode",
        choices=["local", "ollama"],
        default="local",
        help="Inference mode: local (fine-tuned model) or ollama (Ollama API)"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_FINETUNED_MODEL,
        help=f"Path to fine-tuned model (for local mode, default: {DEFAULT_FINETUNED_MODEL})"
    )
    parser.add_argument(
        "--base-model-path",
        type=str,
        default=None,
        help=f"Path to base model (if different from default: {DEFAULT_BASE_MODEL})"
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=OLLAMA_MODEL_NAME,
        help="Ollama model name (default: codellama:7b)"
    )
    parser.add_argument(
        "--ollama-url",
        type=str,
        default=DEFAULT_OLLAMA_URL,
        help="Ollama API URL (default: http://localhost:11434)"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Single prompt to process (if not provided, runs in interactive mode)"
    )
    parser.add_argument(
        "--no-quantization",
        action="store_true",
        help="Disable quantization for local models (requires more memory)"
    )
    parser.add_argument(
        "--merge-weights",
        action="store_true",
        help="Merge LoRA weights into base model (slower loading but faster inference)"
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=800,
        help="Maximum number of new tokens to generate (default: 800)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="Temperature for generation (default: 0.3, lower = more deterministic)"
    )

    args = parser.parse_args()

    # Reject values generate() cannot honour, before any model is loaded.
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be a positive integer")
    if args.temperature < 0:
        parser.error("--temperature must be non-negative")

    use_ollama = args.mode == "ollama"
    # None lets load_local_model choose based on the detected device.
    use_quantization = False if args.no_quantization else None

    if args.prompt:
        single_inference(
            args.prompt,
            use_ollama,
            model_path=args.model_path,
            base_model_path=args.base_model_path,
            ollama_model=args.ollama_model,
            ollama_url=args.ollama_url,
            use_quantization=use_quantization,
            merge_weights=args.merge_weights,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
    else:
        interactive_mode(
            use_ollama,
            args.model_path if not use_ollama else None,
            args.base_model_path,
            args.ollama_model,
            args.ollama_url,
            use_quantization,
            args.merge_weights
        )
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|