zhman committed on
Commit
7afc078
·
1 Parent(s): 100fb2e

Add 8-bit quantization for faster inference

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -4,18 +4,24 @@ HuggingFace Spaces 推理应用
4
  """
5
 
6
  import gradio as gr
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import torch
9
 
10
  # 模型配置
11
  MODEL_NAME = "zhman/llama-SFT-GRPO"
12
 
 
 
 
 
 
 
13
  # 加载模型和分词器
14
  print("🔄 加载模型...")
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_NAME,
18
- torch_dtype=torch.bfloat16,
19
  device_map="auto"
20
  )
21
  print("✅ 模型加载完成!")
 
4
  """
5
 
6
  import gradio as gr
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
8
  import torch
9
 
10
  # 模型配置
11
  MODEL_NAME = "zhman/llama-SFT-GRPO"
12
 
13
+ # 配置 8-bit 量化以提升速度和减少内存占用
14
+ quantization_config = BitsAndBytesConfig(
15
+ load_in_8bit=True,
16
+ llm_int8_threshold=6.0
17
+ )
18
+
19
  # 加载模型和分词器
20
  print("🔄 加载模型...")
21
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
  model = AutoModelForCausalLM.from_pretrained(
23
  MODEL_NAME,
24
+ quantization_config=quantization_config, # 使用 8-bit 量化
25
  device_map="auto"
26
  )
27
  print("✅ 模型加载完成!")