|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
|
from peft import PeftModel |
|
|
|
|
|
def load_model(model_path="final_model_continue"):
    """Load the fine-tuned model and its tokenizer for inference.

    Args:
        model_path: Directory containing the saved PEFT (LoRA) adapter
            weights and tokenizer files.

    Returns:
        tuple: ``(model, tokenizer)`` ready for generation; the model is
        put into eval mode.
    """
    print("🔧 Loading model...")

    # 4-bit NF4 quantization with double quantization keeps the 8B base
    # model within consumer-GPU memory while computing in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Attach the fine-tuned LoRA adapter on top of the quantized base model.
    model = PeftModel.from_pretrained(base_model, model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Llama tokenizers ship without a pad token; fall back to EOS so
    # generation with padding does not raise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Inference only: disable dropout and other training-time behavior.
    # (Previously missing — generation could be run with dropout active.)
    model.eval()

    print("✅ Model loading completed!")
    return model, tokenizer
|
|
|
|
|
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate financial advice response.

    Args:
        model: Causal LM (possibly PEFT-wrapped) used for generation.
        tokenizer: Tokenizer matching the model.
        prompt: Input prompt text.
        max_length: Maximum number of NEW tokens to generate
            (passed as ``max_new_tokens``).

    Returns:
        str: Only the newly generated continuation (prompt stripped).
    """
    # Move inputs to the model's device: with device_map="auto" the model
    # lives on GPU while the tokenizer returns CPU tensors, so omitting
    # this raises a device-mismatch error at generate() time.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt at the TOKEN level rather than slicing the decoded
    # string by len(prompt): detokenization does not always reproduce the
    # prompt text byte-for-byte, so character slicing can clip or leak text.
    prompt_token_count = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: load the fine-tuned model once and run a single
    # sample generation for a technology-stock scenario.
    loaded_model, loaded_tokenizer = load_model()

    prompt = """### Instruction:


Please provide investment advice for investors regarding technology stocks.





### Input:


A technology company's revenue grew 20% this quarter, but profit margin decreased by 5%, mainly due to increased R&D investment. The company has major breakthroughs in AI.





### Response:"""

    # Generate and display the model's advice for the sample scenario.
    advice = generate_response(loaded_model, loaded_tokenizer, prompt)
    print("🤖 AI Investment Advice:")
    print(advice)