gyc12 committed
Commit d3ef534 · verified · 1 Parent(s): 0e4ad8e

Update app.py

Files changed (1): app.py (+43 −120)
app.py CHANGED

@@ -1,134 +1,65 @@
 import gradio as gr
 from transformers import AutoTokenizer, LlamaForCausalLM
 import torch
-import psutil
-import gc
-from typing import List, Tuple
-import logging
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-def print_memory_usage():
-    """Monitor memory usage."""
-    process = psutil.Process()
-    cpu_mem = process.memory_info().rss / 1024 / 1024
-    gpu_mem = torch.cuda.memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0
-    logger.info(f"CPU Memory: {cpu_mem:.2f}MB, GPU Memory: {gpu_mem:.2f}MB")
-
-def optimize_memory():
-    """Optimize memory usage."""
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    print_memory_usage()
-
-# Model configuration
+# Use the UrbanGPT model
 model_name = "bjdwh/UrbanGPT"
 
-try:
-    # Load the model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name,
-        trust_remote_code=True
-    )
-
-    # Load the model with 8-bit quantization
-    model = LlamaForCausalLM.from_pretrained(
-        model_name,
-        load_in_8bit=True,  # enable 8-bit quantization
-        torch_dtype=torch.float16,  # use half precision
-        low_cpu_mem_usage=True,
-        trust_remote_code=True,
-        device_map="auto"  # automatic device mapping
-    )
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
-except Exception as e:
-    logger.error(f"Model loading failed: {str(e)}")
-    raise
+# Load the model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = LlamaForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
 
 def generate_response(
-    message: str,
-    history: List[Tuple[str, str]],
-    max_tokens: int,
-    temperature: float,
-    top_p: float,
+    message,
+    history: list[tuple[str, str]],
+    max_tokens,
+    temperature,
+    top_p,
 ):
-    try:
-        optimize_memory()  # optimize memory usage
-
-        # Format the input
-        input_text = message
-        if history:
-            input_text = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in history]) + f"\nUser: {message}"
-
-        # Encode the input
-        inputs = tokenizer(
-            input_text,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=2048  # cap the input length
-        )
-
-        # Move inputs to the GPU (if available)
-        if torch.cuda.is_available():
-            inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        # Generate a reply
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs["input_ids"],
-                max_length=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                num_return_sequences=1,
-                pad_token_id=tokenizer.eos_token_id,
-                do_sample=True,  # enable sampling
-                repetition_penalty=1.2  # add a repetition penalty
-            )
-
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Extract the last reply
-        if history:
-            response = response.split("Assistant: ")[-1].strip()
-
-        optimize_memory()  # optimize memory again after generation
-        yield response
-
-    except Exception as e:
-        logger.error(f"Error while generating a reply: {str(e)}")
-        yield f"Sorry, an error occurred while generating a reply: {str(e)}"
+    # Format the input
+    input_text = message
+    if history:
+        input_text = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in history]) + f"\nUser: {message}"
+
+    # Encode the input
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
+
+    # Generate a reply
+    with torch.no_grad():
+        outputs = model.generate(
+            inputs["input_ids"],
+            max_length=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # If there is chat history, extract only the last reply
+    if history:
+        response = response.split("Assistant: ")[-1].strip()
+
+    yield response
 
 # Create the Gradio interface
 demo = gr.ChatInterface(
     generate_response,
     additional_inputs=[
-        gr.Slider(
-            minimum=1,
-            maximum=2048,
-            value=512,
-            step=1,
-            label="Max generation length"
-        ),
-        gr.Slider(
-            minimum=0.1,
-            maximum=4.0,
-            value=0.7,
-            step=0.1,
-            label="Temperature"
-        ),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max generation length"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
             value=0.95,
             step=0.05,
-            label="Top-p (nucleus sampling)"
+            label="Top-p (nucleus sampling)",
         ),
     ],
     title="UrbanGPT Chat Assistant",
@@ -136,12 +67,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    # Optimize memory before launch
-    optimize_memory()
-    # Add custom configuration
-    demo.launch(
-        share=False,
-        debug=True,
-        server_name="0.0.0.0",
-        server_port=7860
-    )
+    demo.launch()
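The main effect of the commit is dropping the 8-bit quantized loading path along with the logging and memory housekeeping around it. For reference, the removed load_in_8bit=True keyword has since been deprecated in transformers in favor of an explicit quantization config. A minimal sketch of the equivalent modern call, assuming a CUDA device and the bitsandbytes package, neither of which is part of this commit:

from transformers import BitsAndBytesConfig, LlamaForCausalLM

# Assumption: bitsandbytes is installed and a GPU is available.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = LlamaForCausalLM.from_pretrained(
    "bjdwh/UrbanGPT",
    quantization_config=quant_config,  # replaces the deprecated load_in_8bit=True kwarg
    device_map="auto",                 # dispatch layers across available devices
    trust_remote_code=True,
)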
 
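Both the old and new versions build the prompt by hand from "User:"/"Assistant:" turns. A hedged alternative, assuming the UrbanGPT tokenizer ships a chat template, which is not verified for this checkpoint: the standard apply_chat_template API can produce the same kind of prompt.

# Assumption: the tokenizer defines a chat template; this is not confirmed
# for bjdwh/UrbanGPT, and the hand-rolled formatting above may be intentional.
messages = []
for user_turn, assistant_turn in history:
    messages.append({"role": "user", "content": user_turn})
    messages.append({"role": "assistant", "content": assistant_turn})
messages.append({"role": "user", "content": message})

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant prefix for the next turn
    return_tensors="pt",
)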
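Two details of the simplified generate call are worth flagging. temperature and top_p have no effect in model.generate unless do_sample=True is passed (the removed version set it), and max_length caps the prompt plus the reply, while the slider is labeled as a budget for generated text. A sketch of the call with those points addressed, reusing the names from generate_response above; max_new_tokens and attention_mask are standard transformers generation arguments, not something this commit adds:

inputs = tokenizer(input_text, return_tensors="pt", padding=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}  # keep tensors on the model's device

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # silences padding-related warnings
        max_new_tokens=max_tokens,  # budget for new tokens only, matching the slider label
        do_sample=True,             # required for temperature / top_p to take effect
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )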
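The bare demo.launch() is, in practice, equivalent to the removed call when the app runs on Hugging Face Spaces, where the environment supplies the host and port (0.0.0.0:7860) that were previously hard-coded. For local debugging, the old behavior can be restored explicitly; this is a convenience sketch, not part of the commit:

# Explicit settings for local runs; on Spaces, demo.launch() alone suffices.
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, share=False)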