jolchmo committed on
Commit
13e78d2
·
1 Parent(s): f43468a
Files changed (2) hide show
  1. README.md +3 -2
  2. app.py +12 -8
README.md CHANGED
@@ -24,9 +24,10 @@ hf_oauth: true
24
 
25
  ## 功能特性
26
 
27
- - 💬 实时对话:支持多轮对话,保持上下文
28
- - 🧠 金融专业:基于Llama 3-8B微调的金融领域模型
29
  - 🚀 GPU加速:使用Hugging Face Spaces的GPU支持
 
30
 
31
  ## 使用说明
32
 
 
24
 
25
  ## 功能特性
26
 
27
+ - 💬 智能对话:基于金融领域微调的对话系统
28
+ - 🧠 金融专业:使用Llama 3-8B + LoRA适配器
29
  - 🚀 GPU加速:使用Hugging Face Spaces的GPU支持
30
+ - 💾 智能缓存:模型文件本地缓存,加速启动
31
 
32
  ## 使用说明
33
 
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from peft import PeftModel
6
  import os
7
 
@@ -25,26 +25,30 @@ model_loaded = False
25
 
26
  try:
27
  print("\n[1/3] 加载tokenizer...")
28
- tokenizer = AutoTokenizer.from_pretrained(
29
  model_name,
30
  trust_remote_code=True,
31
- token=hf_token
32
  )
33
  tokenizer.pad_token = tokenizer.eos_token
34
  print("✓ Tokenizer加载成功")
35
 
36
  print("\n[2/3] 加载基础模型...")
37
- base_model = AutoModelForCausalLM.from_pretrained(
38
  model_name,
39
  torch_dtype=torch.float16,
40
  device_map="auto",
41
  trust_remote_code=True,
42
- token=hf_token
 
43
  )
44
- print("✓ 基础模型加载成功")
45
-
46
  print("\n[3/3] 加载LoRA适配器...")
47
- model = PeftModel.from_pretrained(base_model, adapter_name)
 
 
 
 
 
48
  model = model.eval()
49
  print("✓ LoRA适配器加载成功")
50
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from transformers import LlamaTokenizerFast, LlamaForCausalLM
5
  from peft import PeftModel
6
  import os
7
 
 
25
 
26
  try:
27
  print("\n[1/3] 加载tokenizer...")
28
+ tokenizer = LlamaTokenizerFast.from_pretrained(
29
  model_name,
30
  trust_remote_code=True,
31
+ token=hf_token,
32
  )
33
  tokenizer.pad_token = tokenizer.eos_token
34
  print("✓ Tokenizer加载成功")
35
 
36
  print("\n[2/3] 加载基础模型...")
37
+ base_model = LlamaForCausalLM.from_pretrained(
38
  model_name,
39
  torch_dtype=torch.float16,
40
  device_map="auto",
41
  trust_remote_code=True,
42
+ token=hf_token,
43
+ cache_dir=cache_dir
44
  )
 
 
45
  print("\n[3/3] 加载LoRA适配器...")
46
+ model = PeftModel.from_pretrained(base_model, adapter_name, cache_dir=cache_dir)
47
+ model = model.eval()
48
+
49
+ # 确保模型在正确的设备上
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ print(f"✓ LoRA适配器加载成功 (设备: {device})")
52
  model = model.eval()
53
  print("✓ LoRA适配器加载成功")
54