XD-MU commited on
Commit
74a7ccc
·
1 Parent(s): 2d1b51a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -102,30 +102,27 @@ DEMO_DATA = [
102
  # ==========================================
103
  # PART 1: 剧本生成模型 (ScriptAgent)
104
  # ==========================================
105
- from swift.llm import get_model_tokenizer, get_template, inference
 
106
  import torch
107
 
108
  # 全局变量
109
  MODEL_NAME = "XD-MU/ScriptAgent"
110
  LOCAL_MODEL_PATH = "./downloaded_models/ScriptAgent"
111
- OFFLOAD_FOLDER = "./offload"
112
- model = None # 模型对象
113
- tokenizer = None # 分词器对象
114
- template = None # 模板对象
115
 
116
  # 确保目录存在
117
  os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
118
- os.makedirs(OFFLOAD_FOLDER, exist_ok=True)
119
 
120
  def load_llm_model():
121
- """使用 SWIFT 加载 ScriptAgent 模型 - CPU优化版本"""
122
- global model, tokenizer, template
123
- if model is not None:
124
  return
125
 
126
  try:
127
  # 1. 检查本地是否已下载模型
128
- if not os.path.exists(LOCAL_MODEL_PATH):
129
  print(f"正在从 HuggingFace 下载模型到 {LOCAL_MODEL_PATH}...")
130
  snapshot_download(
131
  repo_id=MODEL_NAME,
@@ -137,28 +134,21 @@ def load_llm_model():
137
  else:
138
  print(f"✅ 模型已存在: {LOCAL_MODEL_PATH}")
139
 
140
- # 2. 使用 SWIFT 正确加载模型
141
- print("正在使用 SWIFT 加载模型(CPU + 半精度优化)...")
142
 
143
- # 🔥 关键修改:使用 get_model_tokenizer
144
- model, tokenizer = get_model_tokenizer(
145
  model_id_or_path=LOCAL_MODEL_PATH,
146
- torch_dtype=torch.float16, # 半精度
 
 
147
  model_kwargs={
148
- 'device_map': 'cpu', # CPU设备
149
- 'low_cpu_mem_usage': True, # 低内存模式
150
- 'offload_folder': OFFLOAD_FOLDER, # 内存溢出卸载到磁盘
151
- },
152
- max_model_len=4096, # 限制上下文长度
153
  )
154
 
155
- # 设置为评估模式
156
- model.eval()
157
-
158
- # 获取模板
159
- template = get_template(tokenizer=tokenizer, model=model)
160
-
161
- print("✅ SWIFT 模型加载完成(已启用内存优化)")
162
 
163
  except Exception as e:
164
  print(f"❌ 模型加载失败: {e}")
@@ -166,12 +156,12 @@ def load_llm_model():
166
  traceback.print_exc()
167
 
168
  def chat_with_scriptagent(user_input: str):
169
- """使用 SWIFT 与 ScriptAgent 对话生成剧本"""
170
- global model, tokenizer, template
171
 
172
- if model is None:
173
  load_llm_model()
174
- if model is None:
175
  return "❌ 模型加载失败,请检查后台日志。"
176
 
177
  user_input = user_input.strip()
@@ -179,23 +169,33 @@ def chat_with_scriptagent(user_input: str):
179
  return "请输入内容"
180
 
181
  try:
182
- print("🤖 正在使用 SWIFT 推理剧本...")
183
 
184
- # 🔥 使用 SWIFT 的 inference 函数
185
- response, _ = inference(
186
- model=model,
187
- tokenizer=tokenizer,
188
- template=template,
189
- query=user_input,
190
- max_new_tokens=4096, # 从8192降低到4096
191
- temperature=0.7,
192
- top_p=0.9,
193
- repetition_penalty=1.1,
194
- do_sample=True,
195
- num_beams=1, # 贪婪解码
196
  )
197
 
 
 
 
 
 
 
 
 
 
198
  print(f"✅ 生成结果长度: {len(response)} 字符")
 
199
  return response if response else "⚠️ 生成为空,请重试"
200
 
201
  except Exception as e:
 
102
  # ==========================================
103
  # PART 1: 剧本生成模型 (ScriptAgent)
104
  # ==========================================
105
+ from swift.llm import PtEngine, RequestConfig, InferRequest
106
+ from swift.plugin import InferStats
107
  import torch
108
 
109
  # 全局变量
110
  MODEL_NAME = "XD-MU/ScriptAgent"
111
  LOCAL_MODEL_PATH = "./downloaded_models/ScriptAgent"
112
+ engine = None # InferEngine 对象
 
 
 
113
 
114
  # 确保目录存在
115
  os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
 
116
 
117
  def load_llm_model():
118
+ """使用 ms-swift 的 PtEngine 加载 ScriptAgent 模型"""
119
+ global engine
120
+ if engine is not None:
121
  return
122
 
123
  try:
124
  # 1. 检查本地是否已下载模型
125
+ if not os.path.exists(os.path.join(LOCAL_MODEL_PATH, "config.json")):
126
  print(f"正在从 HuggingFace 下载模型到 {LOCAL_MODEL_PATH}...")
127
  snapshot_download(
128
  repo_id=MODEL_NAME,
 
134
  else:
135
  print(f"✅ 模型已存在: {LOCAL_MODEL_PATH}")
136
 
137
+ # 2. 使用 ms-swift 的 PtEngine 加载模型
138
+ print("正在使用 ms-swift PtEngine 加载模型...")
139
 
140
+ # 🔥 关键修改:使用 PtEngine
141
+ engine = PtEngine(
142
  model_id_or_path=LOCAL_MODEL_PATH,
143
+ torch_dtype=torch.float16, # 半精度
144
+ max_batch_size=1, # 批处理大小
145
+ device_map='cpu', # CPU设备
146
  model_kwargs={
147
+ 'low_cpu_mem_usage': True, # 低内存模式
148
+ }
 
 
 
149
  )
150
 
151
+ print("✅ ms-swift PtEngine 加载完成")
 
 
 
 
 
 
152
 
153
  except Exception as e:
154
  print(f"❌ 模型加载失败: {e}")
 
156
  traceback.print_exc()
157
 
158
  def chat_with_scriptagent(user_input: str):
159
+ """使用 ms-swift InferEngine 与 ScriptAgent 对话生成剧本"""
160
+ global engine
161
 
162
+ if engine is None:
163
  load_llm_model()
164
+ if engine is None:
165
  return "❌ 模型加载失败,请检查后台日志。"
166
 
167
  user_input = user_input.strip()
 
169
  return "请输入内容"
170
 
171
  try:
172
+ print("🤖 正在使用 ms-swift InferEngine 推理剧本...")
173
 
174
+ # 🔥 使用 ms-swift 的推理方式
175
+ # 1. 构建消息格式
176
+ messages = [{'role': 'user', 'content': user_input}]
177
+ infer_request = InferRequest(messages=messages)
178
+
179
+ # 2. 配置请求参数
180
+ request_config = RequestConfig(
181
+ max_tokens=4096, # 最大生成token数
182
+ temperature=0.7, # 温度参数
183
+ top_p=0.9, # top_p 采样
184
+ repetition_penalty=1.1, # 重复惩罚
185
+ stream=False, # 不使用流式输出
186
  )
187
 
188
+ # 3. 执行推理
189
+ metric = InferStats()
190
+ resp_list = engine.infer([infer_request], request_config, metrics=[metric])
191
+
192
+ # 4. 提取结果
193
+ response = resp_list[0].choices[0].message.content
194
+
195
+ # 5. 打印性能指标(可选)
196
+ print(f"✅ 生成完成 | 指标: {metric.compute()}")
197
  print(f"✅ 生成结果长度: {len(response)} 字符")
198
+
199
  return response if response else "⚠️ 生成为空,请重试"
200
 
201
  except Exception as e: