# chatglm-6b-api / app.py
# Author: jiang1002 — "Update app.py" (commit 8f97a31, verified)
import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
import os
from huggingface_hub import InferenceClient
import logging
import traceback
# Set up logging: INFO level on the root logger, plus a module-level logger
# used by the API endpoint and the Gradio chat handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ===== startup smoke tests =====
def _run_startup_tests() -> None:
    """Run best-effort smoke tests of the inference setup at startup.

    Checks that HF_TOKEN is present and well-formed, then tries the model
    through several inference paths (free serverless inference, then the
    ``auto``, ``groq`` and ``together`` providers).  Every failure is printed
    and swallowed so the app always continues to boot.
    """
    print("=" * 50)
    print("🔍 开始测试模型调用")
    print("=" * 50)

    token = os.getenv("HF_TOKEN")
    if not token:
        print("❌ 错误: HF_TOKEN 环境变量未设置!")
    else:
        print(f"✅ HF_TOKEN 已设置 (长度: {len(token)})")
        # User access tokens issued by the Hub start with "hf_".
        if token.startswith("hf_"):
            print("✅ HF_TOKEN 格式正确")
        else:
            print("⚠️ 警告: HF_TOKEN 格式可能不正确,应以 hf_ 开头")

    model_id = "jiang1002/chatglm-6b-adgen"
    print(f"\n📊 正在测试模型 '{model_id}'...")

    # Test 1: plain text-generation via free serverless inference (no provider).
    try:
        print("\n🔄 测试1: 使用 Hugging Face 免费推理...")
        client = InferenceClient(token=token)
        response = client.text_generation("你好", model=model_id, max_new_tokens=20)
        print(f"✅ 免费推理成功! 响应: {response[:50]}...")
    except Exception as e:
        print(f"❌ 免费推理失败: {e}")

    # Tests 2-4: chat-completion through explicit providers.
    # (intro label, display name, provider id) — provider ids must match the
    # identifiers huggingface_hub accepts; the Together AI id is "together".
    chat_tests = [
        ("测试2: 使用 auto provider", "auto provider", "auto"),
        ("测试3: 使用 Groq", "Groq", "groq"),
        ("测试4: 使用 Together AI", "Together AI", "together"),
    ]
    for intro, display, provider in chat_tests:
        try:
            print(f"\n🔄 {intro}...")
            chat_client = InferenceClient(provider=provider, token=token)
            reply = chat_client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": "你好"}],
                max_tokens=20,
            )
            print(f"✅ {display} 成功! 响应: {reply.choices[0].message.content[:50]}...")
        except Exception as e:
            print(f"❌ {display} 失败: {e}")

    print("\n" + "=" * 50)
    print("🔍 测试结束,继续启动应用...")
    print("=" * 50)


# Run once at import time, matching the original script's behavior.
_run_startup_tests()
# ===== startup smoke tests end =====
# Initialize the FastAPI application (Gradio is mounted onto it below).
app = FastAPI()
# Read the Hugging Face token from the environment; without it, provider
# calls will almost certainly be rejected, so warn loudly at startup.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ 未设置 HF_TOKEN 环境变量,API 调用可能失败")
# Shared InferenceClient; provider="auto" lets the Hub pick an available
# inference provider for the model.
client = InferenceClient(provider="auto", token=HF_TOKEN)
# Model repo id served by this app.
MODEL_ID = "jiang1002/chatglm-6b-adgen"  # or swap in another public model
# --- 1. JSON API endpoint ---
@app.post("/generate")
async def generate(request: Request):
    """Generate text from a JSON request body.

    Accepts either of two payload shapes:
      * ``{"messages": [{"role": ..., "content": ...}, ...]}`` — full chat
        format, forwarded to the chat-completions API; or
      * ``{"text": "..."}`` — a plain prompt, sent to text-generation.

    Returns ``{"success": True, "result": <str>}`` on success, or
    ``{"success": False, "error": <str>}`` on any failure (errors are never
    raised to the client).
    """
    try:
        data = await request.json()
        prompt = data.get("text", "")
        messages = data.get("messages", [])

        # Fail fast on an empty request instead of sending an empty prompt
        # to the provider.
        if not messages and not prompt:
            return {"success": False, "error": "缺少 text 或 messages 参数"}

        if messages:
            # Full chat format takes precedence when both are supplied.
            response = client.chat.completions.create(
                model=MODEL_ID,
                messages=messages
            )
            result = response.choices[0].message.content
        else:
            # Plain prompt: simple text generation.
            response = client.text_generation(
                prompt,
                model=MODEL_ID,
                max_new_tokens=512,
                temperature=0.7
            )
            result = response
        return {"success": True, "result": result}
    except Exception as e:
        # Log the full traceback (consistent with chat_func) but return a
        # compact error payload to the caller.
        logger.error(f"API 调用失败: {str(e)}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        return {"success": False, "error": str(e)}
# --- 2. Gradio chat handler ---
def chat_func(message, history):
    """Gradio chat function: forward the conversation to the inference API.

    Args:
        message: the new user message (str).
        history: prior turns — either legacy tuple pairs
            ``[(user, assistant), ...]`` or, with Gradio's
            ``type="messages"`` format, a list of
            ``{"role": ..., "content": ...}`` dicts.  Both are accepted.

    Returns:
        The assistant's reply, or a formatted error string on failure
        (never raises, so the UI always gets a response).
    """
    try:
        # Normalize the history into OpenAI-style messages.
        messages = []
        for turn in history:
            if isinstance(turn, dict):
                # Gradio "messages" history entry: already role/content.
                messages.append({"role": turn["role"], "content": turn["content"]})
            else:
                # Legacy tuple history entry: (user_text, assistant_text).
                human, assistant = turn
                messages.append({"role": "user", "content": human})
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        # Call the Inference API.
        response = client.chat.completions.create(
            model=MODEL_ID,
            messages=messages,
            max_tokens=512,
            temperature=0.7
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"聊天失败: {str(e)}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        return f"调用失败: {str(e)}\n\n{traceback.format_exc()}"
# Build the Gradio chat UI around chat_func.
demo = gr.ChatInterface(
    fn=chat_func,
    title="ChatGLM 广告生成助手 (使用 Inference Providers)",
    description="后台使用 Hugging Face Inference Providers,无需本地 GPU"
)
# Mount the Gradio app at the root path of the FastAPI application,
# so "/" serves the UI while "/generate" and "/health" remain API routes.
app = gr.mount_gradio_app(app, demo, path="/")
# Lightweight liveness probe, also reporting which model is configured.
@app.get("/health")
async def health():
    """Return service status and the configured model id."""
    payload = {"status": "ok"}
    payload["model"] = MODEL_ID
    return payload
if __name__ == "__main__":
    # Spaces injects PORT; fall back to 7860 for local runs.
    serve_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)