Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -102,30 +102,27 @@ DEMO_DATA = [
|
|
| 102 |
# ==========================================
|
| 103 |
# PART 1: 剧本生成模型 (ScriptAgent)
|
| 104 |
# ==========================================
|
| 105 |
-
from swift.llm import
|
|
|
|
| 106 |
import torch
|
| 107 |
|
| 108 |
# 全局变量
|
| 109 |
MODEL_NAME = "XD-MU/ScriptAgent"
|
| 110 |
LOCAL_MODEL_PATH = "./downloaded_models/ScriptAgent"
|
| 111 |
-
|
| 112 |
-
model = None # 模型对象
|
| 113 |
-
tokenizer = None # 分词器对象
|
| 114 |
-
template = None # 模板对象
|
| 115 |
|
| 116 |
# 确保目录存在
|
| 117 |
os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
|
| 118 |
-
os.makedirs(OFFLOAD_FOLDER, exist_ok=True)
|
| 119 |
|
| 120 |
def load_llm_model():
|
| 121 |
-
"""使用
|
| 122 |
-
global
|
| 123 |
-
if
|
| 124 |
return
|
| 125 |
|
| 126 |
try:
|
| 127 |
# 1. 检查本地是否已下载模型
|
| 128 |
-
if not os.path.exists(LOCAL_MODEL_PATH):
|
| 129 |
print(f"正在从 HuggingFace 下载模型到 {LOCAL_MODEL_PATH}...")
|
| 130 |
snapshot_download(
|
| 131 |
repo_id=MODEL_NAME,
|
|
@@ -137,28 +134,21 @@ def load_llm_model():
|
|
| 137 |
else:
|
| 138 |
print(f"✅ 模型已存在: {LOCAL_MODEL_PATH}")
|
| 139 |
|
| 140 |
-
# 2. 使用
|
| 141 |
-
print("正在使用
|
| 142 |
|
| 143 |
-
# 🔥 关键修改:使用
|
| 144 |
-
|
| 145 |
model_id_or_path=LOCAL_MODEL_PATH,
|
| 146 |
-
torch_dtype=torch.float16,
|
|
|
|
|
|
|
| 147 |
model_kwargs={
|
| 148 |
-
'
|
| 149 |
-
|
| 150 |
-
'offload_folder': OFFLOAD_FOLDER, # 内存溢出卸载到磁盘
|
| 151 |
-
},
|
| 152 |
-
max_model_len=4096, # 限制上下文长度
|
| 153 |
)
|
| 154 |
|
| 155 |
-
|
| 156 |
-
model.eval()
|
| 157 |
-
|
| 158 |
-
# 获取模板
|
| 159 |
-
template = get_template(tokenizer=tokenizer, model=model)
|
| 160 |
-
|
| 161 |
-
print("✅ SWIFT 模型加载完成(已启用内存优化)")
|
| 162 |
|
| 163 |
except Exception as e:
|
| 164 |
print(f"❌ 模型加载失败: {e}")
|
|
@@ -166,12 +156,12 @@ def load_llm_model():
|
|
| 166 |
traceback.print_exc()
|
| 167 |
|
| 168 |
def chat_with_scriptagent(user_input: str):
|
| 169 |
-
"""使用
|
| 170 |
-
global
|
| 171 |
|
| 172 |
-
if
|
| 173 |
load_llm_model()
|
| 174 |
-
if
|
| 175 |
return "❌ 模型加载失败,请检查后台日志。"
|
| 176 |
|
| 177 |
user_input = user_input.strip()
|
|
@@ -179,23 +169,33 @@ def chat_with_scriptagent(user_input: str):
|
|
| 179 |
return "请输入内容"
|
| 180 |
|
| 181 |
try:
|
| 182 |
-
print("🤖 正在使用
|
| 183 |
|
| 184 |
-
# 🔥 使用
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
)
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
print(f"✅ 生成结果长度: {len(response)} 字符")
|
|
|
|
| 199 |
return response if response else "⚠️ 生成为空,请重试"
|
| 200 |
|
| 201 |
except Exception as e:
|
|
|
|
| 102 |
# ==========================================
|
| 103 |
# PART 1: 剧本生成模型 (ScriptAgent)
|
| 104 |
# ==========================================
|
| 105 |
+
from swift.llm import PtEngine, RequestConfig, InferRequest
|
| 106 |
+
from swift.plugin import InferStats
|
| 107 |
import torch
|
| 108 |
|
| 109 |
# 全局变量
|
| 110 |
MODEL_NAME = "XD-MU/ScriptAgent"
|
| 111 |
LOCAL_MODEL_PATH = "./downloaded_models/ScriptAgent"
|
| 112 |
+
engine = None # InferEngine 对象
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# 确保目录存在
|
| 115 |
os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
|
|
|
|
| 116 |
|
| 117 |
def load_llm_model():
|
| 118 |
+
"""使用 ms-swift 的 PtEngine 加载 ScriptAgent 模型"""
|
| 119 |
+
global engine
|
| 120 |
+
if engine is not None:
|
| 121 |
return
|
| 122 |
|
| 123 |
try:
|
| 124 |
# 1. 检查本地是否已下载模型
|
| 125 |
+
if not os.path.exists(os.path.join(LOCAL_MODEL_PATH, "config.json")):
|
| 126 |
print(f"正在从 HuggingFace 下载模型到 {LOCAL_MODEL_PATH}...")
|
| 127 |
snapshot_download(
|
| 128 |
repo_id=MODEL_NAME,
|
|
|
|
| 134 |
else:
|
| 135 |
print(f"✅ 模型已存在: {LOCAL_MODEL_PATH}")
|
| 136 |
|
| 137 |
+
# 2. 使用 ms-swift 的 PtEngine 加载模型
|
| 138 |
+
print("正在使用 ms-swift PtEngine 加载模型...")
|
| 139 |
|
| 140 |
+
# 🔥 关键修改:使用 PtEngine
|
| 141 |
+
engine = PtEngine(
|
| 142 |
model_id_or_path=LOCAL_MODEL_PATH,
|
| 143 |
+
torch_dtype=torch.float16, # 半精度
|
| 144 |
+
max_batch_size=1, # 批处理大小
|
| 145 |
+
device_map='cpu', # CPU设备
|
| 146 |
model_kwargs={
|
| 147 |
+
'low_cpu_mem_usage': True, # 低内存模式
|
| 148 |
+
}
|
|
|
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
+
print("✅ ms-swift PtEngine 加载完成")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
print(f"❌ 模型加载失败: {e}")
|
|
|
|
| 156 |
traceback.print_exc()
|
| 157 |
|
| 158 |
def chat_with_scriptagent(user_input: str):
|
| 159 |
+
"""使用 ms-swift InferEngine 与 ScriptAgent 对话生成剧本"""
|
| 160 |
+
global engine
|
| 161 |
|
| 162 |
+
if engine is None:
|
| 163 |
load_llm_model()
|
| 164 |
+
if engine is None:
|
| 165 |
return "❌ 模型加载失败,请检查后台日志。"
|
| 166 |
|
| 167 |
user_input = user_input.strip()
|
|
|
|
| 169 |
return "请输入内容"
|
| 170 |
|
| 171 |
try:
|
| 172 |
+
print("🤖 正在使用 ms-swift InferEngine 推理剧本...")
|
| 173 |
|
| 174 |
+
# 🔥 使用 ms-swift 的推理方式
|
| 175 |
+
# 1. 构建消息格式
|
| 176 |
+
messages = [{'role': 'user', 'content': user_input}]
|
| 177 |
+
infer_request = InferRequest(messages=messages)
|
| 178 |
+
|
| 179 |
+
# 2. 配置请求参数
|
| 180 |
+
request_config = RequestConfig(
|
| 181 |
+
max_tokens=4096, # 最大生成token数
|
| 182 |
+
temperature=0.7, # 温度参数
|
| 183 |
+
top_p=0.9, # top_p 采样
|
| 184 |
+
repetition_penalty=1.1, # 重复惩罚
|
| 185 |
+
stream=False, # 不使用流式输出
|
| 186 |
)
|
| 187 |
|
| 188 |
+
# 3. 执行推理
|
| 189 |
+
metric = InferStats()
|
| 190 |
+
resp_list = engine.infer([infer_request], request_config, metrics=[metric])
|
| 191 |
+
|
| 192 |
+
# 4. 提取结果
|
| 193 |
+
response = resp_list[0].choices[0].message.content
|
| 194 |
+
|
| 195 |
+
# 5. 打印性能指标(可选)
|
| 196 |
+
print(f"✅ 生成完成 | 指标: {metric.compute()}")
|
| 197 |
print(f"✅ 生成结果长度: {len(response)} 字符")
|
| 198 |
+
|
| 199 |
return response if response else "⚠️ 生成为空,请重试"
|
| 200 |
|
| 201 |
except Exception as e:
|