TimberGu committed on
Commit
2777377
·
verified ·
1 Parent(s): 4d951b3

Upload use_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. use_model.py +70 -0
use_model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Financial LLaMA Model Usage Script
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
+ from peft import PeftModel
5
+
6
def load_model(model_path="final_model_continue"):
    """Load the 4-bit quantized base model and attach the fine-tuned LoRA adapter.

    Args:
        model_path: Directory holding the saved LoRA adapter weights and
            the tokenizer produced by fine-tuning.

    Returns:
        (model, tokenizer): a PEFT-wrapped causal LM in eval mode and its
        matching tokenizer.
    """
    print("🔧 Loading model...")

    # NF4 double quantization keeps the 8B base model small enough for a
    # single GPU; matmuls are computed in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load the frozen base model; device_map="auto" lets accelerate place
    # layers on whatever devices are available.
    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Apply the LoRA adapter produced by fine-tuning on top of the base model.
    model = PeftModel.from_pretrained(base_model, model_path)
    # Inference-only script: switch off dropout / training-mode behavior.
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # LLaMA tokenizers ship without a pad token; reuse EOS so padding during
    # tokenization/generation does not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("✅ Model loading completed!")
    return model, tokenizer
35
+
36
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a completion for *prompt* and return only the new text.

    Args:
        model: Causal LM (base + LoRA adapter) returned by load_model().
        tokenizer: Tokenizer matching the model.
        prompt: Fully formatted instruction prompt.
        max_length: Maximum number of NEW tokens to generate (passed to
            ``max_new_tokens``).

    Returns:
        The generated continuation as a string, with the prompt and any
        special tokens stripped.
    """
    # Move inputs onto the model's device: with device_map="auto" the model
    # may live on GPU while tokenizer output defaults to CPU, which would
    # otherwise raise a device-mismatch error inside generate().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt by TOKEN count, not by character count: decoding does
    # not round-trip the prompt byte-exactly (whitespace/special-token
    # normalization), so slicing the decoded string at len(prompt) can chop
    # the response or leak prompt text.
    prompt_token_count = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
52
+
53
# Usage example
def _demo():
    """Load the fine-tuned model and print advice for one sample scenario."""
    model, tokenizer = load_model()

    demo_prompt = """### Instruction:
Please provide investment advice for investors regarding technology stocks.

### Input:
A technology company's revenue grew 20% this quarter, but profit margin decreased by 5%, mainly due to increased R&D investment. The company has major breakthroughs in AI.

### Response:"""

    advice = generate_response(model, tokenizer, demo_prompt)
    print("🤖 AI Investment Advice:")
    print(advice)


if __name__ == "__main__":
    _demo()