jiang1002 committed on
Commit
561bc7d
·
verified ·
1 Parent(s): 8e63062

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -6,23 +6,26 @@ import torch
6
 
7
  app = FastAPI()
8
 
9
- # 1. 加载模型
10
  MODEL_PATH = "jiang1002/chatglm-6b-adgen"
11
- print("🚀 正在加载模型...")
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
14
- # 强制 CPU 运行,.float() 是为了防止精度溢出导致内存崩掉
 
 
15
  model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
16
  model.eval()
17
- print("✅ 模型就绪!")
18
 
19
- # 2. 别人调用的 API 接口 (POST /generate)
 
 
20
  @app.post("/generate")
21
  async def generate(request: Request):
22
  try:
23
  data = await request.json()
24
  prompt = data.get("text", "")
25
- # 直接用本地模型,不求人
26
  response, _ = model.chat(tokenizer, prompt, history=[])
27
  return {"success": True, "result": response}
28
  except Exception as e:
@@ -33,7 +36,7 @@ def chat_func(msg, hist):
33
  res, _ = model.chat(tokenizer, msg, history=hist)
34
  return res
35
 
36
- demo = gr.ChatInterface(fn=chat_func, title="ChatGLM API Server")
37
  app = gr.mount_gradio_app(app, demo, path="/")
38
 
39
  if __name__ == "__main__":
 
6
 
7
  app = FastAPI()
8
 
9
+ # 1. 加载模型逻辑
10
  MODEL_PATH = "jiang1002/chatglm-6b-adgen"
11
+ print("🚀 正在加载模型到 CPU (这需要大约 15GB 内存,请确保 Space 没爆内存)...")
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
14
+
15
+ # 【核心修改】强制使用 .float() 并在 CPU 上运行
16
+ # 不要写 .cuda(),不要写 device='cuda'
17
  model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
18
  model.eval()
 
19
 
20
+ print("✅ CPU 模式加载成功!API 已就绪。")
21
+
22
+ # 2. 别人调用的 API 接口
23
  @app.post("/generate")
24
  async def generate(request: Request):
25
  try:
26
  data = await request.json()
27
  prompt = data.get("text", "")
28
+ # 直接在 CPU 上推理
29
  response, _ = model.chat(tokenizer, prompt, history=[])
30
  return {"success": True, "result": response}
31
  except Exception as e:
 
36
  res, _ = model.chat(tokenizer, msg, history=hist)
37
  return res
38
 
39
+ demo = gr.ChatInterface(fn=chat_func, title="ChatGLM CPU API Server")
40
  app = gr.mount_gradio_app(app, demo, path="/")
41
 
42
  if __name__ == "__main__":