File size: 5,704 Bytes
2e301f3
6e6d483
 
615868a
c88a9bb
 
6982706
b921357
c88a9bb
 
 
615868a
8f97a31
 
 
 
e6e51ad
 
 
8f97a31
 
 
 
 
 
 
e6e51ad
8f97a31
 
 
6982706
8f97a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6982706
8f97a31
6982706
8f97a31
 
 
 
6982706
8f97a31
6982706
8f97a31
6982706
8f97a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615868a
8f97a31
 
 
 
6e6d483
 
5368ce1
 
 
c88a9bb
 
8f97a31
c88a9bb
 
 
 
 
 
 
8f97a31
 
 
c88a9bb
8f97a31
 
c88a9bb
8f97a31
c88a9bb
 
5368ce1
c88a9bb
5368ce1
 
8f97a31
c88a9bb
 
 
8f97a31
c88a9bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f97a31
 
6e6d483
c88a9bb
615868a
c88a9bb
8f97a31
 
615868a
 
c88a9bb
6e6d483
2e301f3
8f97a31
c88a9bb
 
 
 
2e301f3
c88a9bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
import os
from huggingface_hub import InferenceClient
import logging
import traceback

# Logging setup: INFO-level root configuration plus a module-scoped logger
# used by the API endpoint and the Gradio chat handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== 测试代码开始 =====
# Startup diagnostics: before the app boots, probe the model through several
# inference routes so the Space logs show which provider (if any) can serve it.
# All failures are printed, never raised — the app starts regardless.
print("="*50)
print("🔍 开始测试模型调用")
print("="*50)

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("❌ 错误: HF_TOKEN 环境变量未设置!")
else:
    print(f"✅ HF_TOKEN 已设置 (长度: {len(HF_TOKEN)})")
    if HF_TOKEN.startswith("hf_"):
        print("✅ HF_TOKEN 格式正确")
    else:
        print("⚠️ 警告: HF_TOKEN 格式可能不正确,应以 hf_ 开头")

# Model under test (same id the app serves below).
model_id = "jiang1002/chatglm-6b-adgen"
print(f"\n📊 正在测试模型 '{model_id}'...")

# Probe 1: plain Hugging Face serverless text-generation (no provider given).
try:
    print("\n🔄 测试1: 使用 Hugging Face 免费推理...")
    client1 = InferenceClient(token=HF_TOKEN)
    response1 = client1.text_generation(
        "你好",
        model=model_id,
        max_new_tokens=20
    )
    print(f"✅ 免费推理成功! 响应: {response1[:50]}...")
except Exception as e:
    print(f"❌ 免费推理失败: {str(e)}")


def _probe_chat_provider(provider: str, test_no: int, label: str) -> None:
    """Run one chat-completion smoke test against *provider* and print the outcome.

    Parameters:
        provider: provider id passed to InferenceClient (e.g. "groq").
        test_no:  sequence number shown in the log line.
        label:    human-readable provider name used in the printed messages.

    Failures are reported via print, never raised — purely informational.
    """
    try:
        print(f"\n🔄 测试{test_no}: 使用 {label}...")
        probe_client = InferenceClient(provider=provider, token=HF_TOKEN)
        response = probe_client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": "你好"}],
            max_tokens=20
        )
        print(f"✅ {label} 成功! 响应: {response.choices[0].message.content[:50]}...")
    except Exception as e:
        print(f"❌ {label} 失败: {str(e)}")


# Probes 2-4 were three duplicated try/except blocks differing only in the
# provider id and label; run them through the shared helper instead.
_probe_chat_provider("auto", 2, "auto provider")
_probe_chat_provider("groq", 3, "Groq")
_probe_chat_provider("together-ai", 4, "Together AI")

print("\n" + "="*50)
print("🔍 测试结束,继续启动应用...")
print("="*50)
# ===== 测试代码结束 =====

# Initialize FastAPI (Gradio is mounted onto this app further below).
app = FastAPI()

# Read the Hugging Face token from the environment; without it, downstream
# Inference API calls will likely fail with an auth error.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ 未设置 HF_TOKEN 环境变量,API 调用可能失败")

# Initialize the InferenceClient.
# provider="auto" lets the hub automatically pick an available inference provider.
client = InferenceClient(provider="auto", token=HF_TOKEN)

# Model id served by both the /generate endpoint and the Gradio chat UI.
MODEL_ID = "jiang1002/chatglm-6b-adgen"  # or swap in another public model

# --- 1. API 接口 ---
@app.post("/generate")
async def generate(request: Request):
    """Generate text through the Inference API.

    Accepts a JSON body containing either:
      * ``messages`` — a full chat-completion message list
        (``[{"role": ..., "content": ...}, ...]``), preferred when present; or
      * ``text``     — a plain prompt for single-turn text generation.

    Returns ``{"success": True, "result": <str>}`` on success, or
    ``{"success": False, "error": <str>}`` on any failure (never raises).
    """
    try:
        data = await request.json()
        prompt = data.get("text", "")
        messages = data.get("messages", [])

        # Reject empty requests up front instead of burning an API call on "".
        if not messages and not prompt:
            return {"success": False, "error": "请求中缺少 'text' 或 'messages' 字段"}

        # Prefer the full messages format when the caller supplies it.
        if messages:
            response = client.chat.completions.create(
                model=MODEL_ID,
                messages=messages,
                max_tokens=512,      # keep limits consistent with the Gradio chat path
                temperature=0.7
            )
            result = response.choices[0].message.content
        else:
            # Otherwise fall back to simple single-prompt text generation.
            response = client.text_generation(
                prompt,
                model=MODEL_ID,
                max_new_tokens=512,
                temperature=0.7
            )
            result = response

        return {"success": True, "result": result}
    except Exception as e:
        # Boundary handler: log and return a structured error instead of a 500.
        logger.error(f"API 调用失败: {str(e)}")
        return {"success": False, "error": str(e)}

# --- 2. Gradio 聊天界面 ---
def chat_func(message, history):
    """Gradio chat handler: forward the conversation to the Inference API.

    Parameters:
        message: the user's newest utterance (str).
        history: prior turns. Gradio may deliver these as ``(user, assistant)``
            tuples (classic format) or as OpenAI-style ``{"role", "content"}``
            dicts (``type="messages"``); both are normalized here so the
            handler works regardless of the ChatInterface configuration.

    Returns the assistant's reply text, or a formatted error string on failure.
    """
    try:
        # Normalize the history into chat-completion messages format.
        messages = []
        for entry in history:
            if isinstance(entry, dict):
                # Already messages-format — pass role/content through.
                messages.append({"role": entry["role"], "content": entry["content"]})
            else:
                human, assistant = entry
                messages.append({"role": "user", "content": human})
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        # Call the Inference API.
        response = client.chat.completions.create(
            model=MODEL_ID,
            messages=messages,
            max_tokens=512,
            temperature=0.7
        )

        return response.choices[0].message.content
    except Exception as e:
        # Surface the full traceback both in the logs and in the chat window
        # so provider failures are debuggable from the UI.
        logger.error(f"聊天失败: {str(e)}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        return f"调用失败: {str(e)}\n\n{traceback.format_exc()}"

# Build the Gradio chat UI on top of chat_func.
demo = gr.ChatInterface(
    fn=chat_func,
    title="ChatGLM 广告生成助手 (使用 Inference Providers)",
    description="后台使用 Hugging Face Inference Providers,无需本地 GPU"
)

# Mount the Gradio app at the FastAPI root path; /generate and /health
# remain available alongside the UI.
app = gr.mount_gradio_app(app, demo, path="/")

# Health-check endpoint for liveness probes.
@app.get("/health")
async def health():
    """Report service liveness and the configured model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload

if __name__ == "__main__":
    # Hosting platforms (e.g. HF Spaces) may override the port via $PORT;
    # 7860 is the conventional Gradio default.
    serve_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)