gyc12 committed
Commit 0e4ad8e · verified · 1 Parent(s): 505eb00

Update app.py

Files changed (1)
  1. app.py +120 -43
app.py CHANGED
@@ -1,65 +1,134 @@
 import gradio as gr
 from transformers import AutoTokenizer, LlamaForCausalLM
 import torch
+import psutil
+import gc
+from typing import List, Tuple
+import logging
 
-# Use the UrbanGPT model
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def print_memory_usage():
+    """Monitor memory usage."""
+    process = psutil.Process()
+    cpu_mem = process.memory_info().rss / 1024 / 1024
+    gpu_mem = torch.cuda.memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0
+    logger.info(f"CPU Memory: {cpu_mem:.2f}MB, GPU Memory: {gpu_mem:.2f}MB")
+
+def optimize_memory():
+    """Optimize memory usage."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    print_memory_usage()
+
+# Model configuration
 model_name = "bjdwh/UrbanGPT"
 
-# Load the model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = LlamaForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float16,
-    low_cpu_mem_usage=True,
-    trust_remote_code=True
-)
+try:
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True
+    )
+
+    # Load the model with 8-bit quantization
+    model = LlamaForCausalLM.from_pretrained(
+        model_name,
+        load_in_8bit=True,  # enable 8-bit quantization
+        torch_dtype=torch.float16,  # use half precision
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        device_map="auto"  # automatic device mapping
+    )
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
+except Exception as e:
+    logger.error(f"模型加载失败: {str(e)}")
+    raise
 
 def generate_response(
-    message,
-    history: list[tuple[str, str]],
-    max_tokens,
-    temperature,
-    top_p,
+    message: str,
+    history: List[Tuple[str, str]],
+    max_tokens: int,
+    temperature: float,
+    top_p: float,
 ):
-    # Format the input
-    input_text = message
-    if history:
-        input_text = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in history]) + f"\nUser: {message}"
-
-    # Encode the input
-    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
-
-    # Generate a reply
-    with torch.no_grad():
-        outputs = model.generate(
-            inputs["input_ids"],
-            max_length=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            num_return_sequences=1,
-            pad_token_id=tokenizer.eos_token_id
+    try:
+        optimize_memory()  # optimize memory usage
+
+        # Format the input
+        input_text = message
+        if history:
+            input_text = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in history]) + f"\nUser: {message}"
+
+        # Encode the input
+        inputs = tokenizer(
+            input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=2048  # cap the input length
         )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # If there is conversation history, extract only the last reply
-    if history:
-        response = response.split("Assistant: ")[-1].strip()
-
-    yield response
+
+        # Move inputs to the GPU (if available)
+        if torch.cuda.is_available():
+            inputs = {k: v.cuda() for k, v in inputs.items()}
+
+        # Generate a reply
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs["input_ids"],
+                max_length=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                num_return_sequences=1,
+                pad_token_id=tokenizer.eos_token_id,
+                do_sample=True,  # enable sampling
+                repetition_penalty=1.2  # add a repetition penalty
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract the last reply
+        if history:
+            response = response.split("Assistant: ")[-1].strip()
+
+        optimize_memory()  # optimize memory again after generation
+        yield response
+
+    except Exception as e:
+        logger.error(f"生成回复时发生错误: {str(e)}")
+        yield f"抱歉,生成回复时发生错误: {str(e)}"
 
 # Create the Gradio interface
 demo = gr.ChatInterface(
     generate_response,
     additional_inputs=[
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="生成最大长度"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="温度"),
+        gr.Slider(
+            minimum=1,
+            maximum=2048,
+            value=512,
+            step=1,
+            label="生成最大长度"
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=4.0,
+            value=0.7,
+            step=0.1,
+            label="温度"
+        ),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
             value=0.95,
             step=0.05,
-            label="Top-p (核采样)",
+            label="Top-p (核采样)"
         ),
     ],
     title="UrbanGPT 聊天助手",
@@ -67,4 +136,12 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    # Optimize memory before launch
+    optimize_memory()
+    # Custom launch configuration
+    demo.launch(
+        share=False,
+        debug=True,
+        server_name="0.0.0.0",
+        server_port=7860
+    )
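
A note on the quantized load path added in this commit: load_in_8bit=True only takes effect when the bitsandbytes package and a CUDA device are available, and newer transformers releases route 8-bit loading through a BitsAndBytesConfig object rather than the bare keyword. A minimal sketch of that variant, assuming bitsandbytes is installed (it is not touched by this commit):

from transformers import BitsAndBytesConfig, LlamaForCausalLM
import torch

# Sketch only: newer-style 8-bit loading; assumes bitsandbytes is installed.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = LlamaForCausalLM.from_pretrained(
    "bjdwh/UrbanGPT",
    quantization_config=quant_config,  # replaces the bare load_in_8bit kwarg
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
)

Either form expects a GPU; on a CPU-only Space the 8-bit path will generally fail at load time and the except branch above will re-raise the error.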
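
On the generation side, the updated model.generate call still receives only input_ids and bounds output with max_length, which counts prompt tokens as well, so a long history can leave little or no room for new text. One possible variant inside generate_response, reusing the commit's own inputs, max_tokens, temperature and top_p variables (a sketch, not part of this commit):

        # Sketch: pass the attention mask and bound only the newly generated tokens.
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],  # mask returned by the tokenizer
                max_new_tokens=max_tokens,  # counts generated tokens only, not the prompt
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
            )

With max_new_tokens the "生成最大长度" slider then controls the reply length directly, independent of how long the accumulated conversation prompt is.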