jiang1002 committed on
Commit
c88a9bb
·
verified ·
1 Parent(s): 615868a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -33
app.py CHANGED
@@ -1,58 +1,98 @@
1
  import gradio as gr
2
  from fastapi import FastAPI, Request
3
  import uvicorn
4
- from transformers import AutoTokenizer, AutoModel
5
- import torch
6
  import os
 
 
7
 
8
- # 初始化接口
9
- app = FastAPI()
10
-
11
- # --- 模型配置 ---
12
- # 如果同学本地有模型文件,可以改成文件夹路径
13
- MODEL_PATH = "jiang1002/chatglm-6b-adgen"
14
 
15
- print("🚀 正在加载模型,请稍候...")
 
16
 
17
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
 
 
 
18
 
19
- # 自动检测设备:有显卡用显卡,没显卡用 CPU
20
- if torch.cuda.is_available():
21
- print("✨ 检测到 GPU,正在使用显卡加速...")
22
- model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).half().cuda()
23
- else:
24
- print("☁️ 未检测到 GPU,正在使用 CPU 模式(速度较慢)...")
25
- model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float()
26
 
27
- model.eval()
28
- print(" 模型加载成功!")
29
 
30
- # --- 1. 给别人用的 API 接口 ---
31
  @app.post("/generate")
32
  async def generate(request: Request):
33
  try:
34
  data = await request.json()
35
  prompt = data.get("text", "")
36
- # 模型推理
37
- response, _ = model.chat(tokenizer, prompt, history=[])
38
- return {"success": True, "result": response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
 
40
  return {"success": False, "error": str(e)}
41
 
42
- # --- 2. 给自己用的网页界面 ---
43
- def chat_func(msg, hist):
44
- res, _ = model.chat(tokenizer, msg, history=hist)
45
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
 
47
  demo = gr.ChatInterface(
48
- fn=chat_func,
49
- title="ChatGLM 广告生成助手",
50
- description="本程序已同时开启 API 接口(路径:/generate)"
51
  )
52
 
53
- # 挂载 Gradio 到 FastAPI
54
  app = gr.mount_gradio_app(app, demo, path="/")
55
 
 
 
 
 
 
56
  if __name__ == "__main__":
57
- # 启动服务器,默认端口 7860
58
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import gradio as gr
2
  from fastapi import FastAPI, Request
3
  import uvicorn
 
 
4
  import os
5
+ from huggingface_hub import InferenceClient
6
+ import logging
7
 
8
# Logging: INFO level on the root logger, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI()

# Hugging Face access token, read from the environment; warn when it is absent.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ 未设置 HF_TOKEN 环境变量,API 调用可能失败")

# Inference client; provider="auto" leaves the choice of an available
# provider to the library.
client = InferenceClient(provider="auto", token=HF_TOKEN)

# Model repository to query (may be swapped for another public model).
MODEL_ID = "jiang1002/chatglm-6b-adgen"
26
 
27
+ # --- 1. API 接口 ---
28
@app.post("/generate")
async def generate(request: Request):
    """Generate text for a JSON request body.

    The body must be a JSON object containing either ``messages`` (a
    chat-style list, preferred when non-empty) or ``text`` (a plain prompt).

    Returns:
        ``{"success": True, "result": <generated text>}`` on success, or
        ``{"success": False, "error": <message>}`` when the input is unusable
        or the inference call fails. Exceptions are never propagated.
    """
    # Local import: the inference client is synchronous, so its calls are
    # pushed to a worker thread instead of blocking the event loop.
    from fastapi.concurrency import run_in_threadpool

    try:
        data = await request.json()
        if not isinstance(data, dict):
            return {"success": False, "error": "Request body must be a JSON object"}

        prompt = data.get("text", "")
        messages = data.get("messages", [])

        if messages:
            # Full chat format supplied: use the chat-completions API with the
            # same sampling settings as the Gradio chat function.
            response = await run_in_threadpool(
                client.chat.completions.create,
                model=MODEL_ID,
                messages=messages,
                max_tokens=512,
                temperature=0.7,
            )
            result = response.choices[0].message.content
        elif prompt:
            # Plain prompt: use the text-generation API.
            result = await run_in_threadpool(
                client.text_generation,
                prompt,
                model=MODEL_ID,
                max_new_tokens=512,
                temperature=0.7,
            )
        else:
            # Nothing to generate from; do not send an empty prompt upstream.
            return {
                "success": False,
                "error": "Request must include a non-empty 'text' or 'messages' field",
            }

        return {"success": True, "result": result}
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure in-band.
        logger.exception("API 调用失败: %s", e)
        return {"success": False, "error": str(e)}
56
 
57
+ # --- 2. Gradio 聊天界面 ---
58
def _history_to_messages(history, message):
    """Convert Gradio chat history plus the new user message to role/content dicts.

    Accepts both history layouts: a list of ``(user_text, assistant_text)``
    pairs, or a list of ``{"role": ..., "content": ...}`` dicts. Only non-empty
    text content is forwarded; ``None`` and non-text entries are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style entry: already carries its own role.
            candidates = [(turn.get("role"), turn.get("content"))]
        else:
            # Pair-style entry: (user_text, assistant_text).
            human, assistant = turn
            candidates = [("user", human), ("assistant", assistant)]
        for role, content in candidates:
            if role in ("user", "assistant", "system") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})
    return messages


def chat_func(message, history):
    """Gradio chat callback: answer ``message`` given the prior ``history``.

    Returns the model's reply text, or a "调用失败: ..." string describing the
    error when the inference call fails (the UI shows it as the reply).
    """
    try:
        messages = _history_to_messages(history, message)

        # Call the Inference API.
        response = client.chat.completions.create(
            model=MODEL_ID,
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Top-level boundary for the UI: log with traceback, return the error text.
        logger.exception("聊天失败: %s", e)
        return f"调用失败: {str(e)}"
80
 
81
# Build the Gradio chat UI.
demo = gr.ChatInterface(
    fn=chat_func,
    title="ChatGLM 广告生成助手 (使用 Inference Providers)",
    description="后台使用 Hugging Face Inference Providers,无需本地 GPU"
)


# Health-check endpoint. It must be registered BEFORE Gradio is mounted at
# "/": routes are matched in registration order and a mount at the root path
# matches every URL, so any route added after it would never be reached.
@app.get("/health")
async def health():
    """Report service status and the configured model id."""
    return {"status": "ok", "model": MODEL_ID}


# Mount Gradio at the root path last, so it only receives requests that no
# earlier route handled.
app = gr.mount_gradio_app(app, demo, path="/")
95
+
96
if __name__ == "__main__":
    # Serve on all interfaces; the port comes from $PORT, falling back to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))