import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
import os
from huggingface_hub import InferenceClient
import logging
import traceback

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face access token, shared by the startup diagnostics and the app client.
# (The original fetched this twice; consolidated to one read.)
HF_TOKEN = os.getenv("HF_TOKEN")

# Model served through the Inference Providers API (or swap in another public model).
MODEL_ID = "jiang1002/chatglm-6b-adgen"


def _run_startup_diagnostics() -> None:
    """Smoke-test the token and model against several inference providers.

    Purely informational: every failure is caught and printed so the app
    still starts even when no provider can serve the model.
    """
    print("=" * 50)
    print("🔍 开始测试模型调用")
    print("=" * 50)

    if not HF_TOKEN:
        print("❌ 错误: HF_TOKEN 环境变量未设置!")
    else:
        print(f"✅ HF_TOKEN 已设置 (长度: {len(HF_TOKEN)})")
        if HF_TOKEN.startswith("hf_"):
            print("✅ HF_TOKEN 格式正确")
        else:
            print("⚠️ 警告: HF_TOKEN 格式可能不正确,应以 hf_ 开头")

    print(f"\n📊 正在测试模型 '{MODEL_ID}'...")

    # Test 1: plain Hugging Face free inference (no provider specified).
    try:
        print("\n🔄 测试1: 使用 Hugging Face 免费推理...")
        free_client = InferenceClient(token=HF_TOKEN)
        reply = free_client.text_generation("你好", model=MODEL_ID, max_new_tokens=20)
        print(f"✅ 免费推理成功! 响应: {reply[:50]}...")
    except Exception as e:
        print(f"❌ 免费推理失败: {str(e)}")

    # Tests 2-4: identical chat-completion calls through specific providers,
    # collapsed into one data-driven loop.  BUG FIX: the original used
    # provider="together-ai", which is not a valid huggingface_hub provider
    # id — the correct id is "together", so test 4 could never succeed.
    chat_tests = [
        ("测试2", "auto provider", "auto"),
        ("测试3", "Groq", "groq"),
        ("测试4", "Together AI", "together"),  # fixed: was "together-ai"
    ]
    for test_name, label, provider in chat_tests:
        try:
            print(f"\n🔄 {test_name}: 使用 {label}...")
            test_client = InferenceClient(provider=provider, token=HF_TOKEN)
            response = test_client.chat.completions.create(
                model=MODEL_ID,
                messages=[{"role": "user", "content": "你好"}],
                max_tokens=20,
            )
            print(f"✅ {label} 成功! 响应: {response.choices[0].message.content[:50]}...")
        except Exception as e:
            print(f"❌ {label} 失败: {str(e)}")

    print("\n" + "=" * 50)
    print("🔍 测试结束,继续启动应用...")
    print("=" * 50)


_run_startup_diagnostics()

# FastAPI application (the Gradio UI is mounted onto it further below).
app = FastAPI()

if not HF_TOKEN:
    logger.warning("⚠️ 未设置 HF_TOKEN 环境变量,API 调用可能失败")

# provider="auto" lets huggingface_hub pick the first provider that can serve the model.
client = InferenceClient(provider="auto", token=HF_TOKEN)


# --- 1. API endpoint ---
@app.post("/generate")
async def generate(request: Request):
    """Generate text from a JSON request body.

    Accepts either {"messages": [...]} (OpenAI-style chat messages) or
    {"text": "..."} (plain prompt); the messages form takes precedence.
    Returns {"success": True, "result": str} on success, or
    {"success": False, "error": str} on failure (never a raw 500).
    """
    try:
        data = await request.json()
        prompt = data.get("text", "")
        messages = data.get("messages", [])

        if messages:
            # Full chat format.
            response = client.chat.completions.create(
                model=MODEL_ID,
                messages=messages,
            )
            result = response.choices[0].message.content
        else:
            # Simple single-prompt text generation.
            result = client.text_generation(
                prompt,
                model=MODEL_ID,
                max_new_tokens=512,
                temperature=0.7,
            )
        return {"success": True, "result": result}
    except Exception as e:
        # API boundary: log and report the failure to the caller.
        logger.error(f"API 调用失败: {str(e)}")
        return {"success": False, "error": str(e)}
# --- 2. Gradio chat UI ---
def chat_func(message, history):
    """Gradio chat handler: forward the conversation to the Inference API.

    ``history`` may arrive as (user, assistant) pairs (classic Gradio tuple
    format) or as {"role", "content"} dicts (``type="messages"`` in newer
    Gradio versions).  The original unpacked tuples only, which raises on
    dict-style history; both formats are now normalized into OpenAI-style
    messages.  On failure the error plus traceback is returned as the chat
    reply for easy debugging in the UI.
    """
    try:
        messages = []
        for entry in history:
            if isinstance(entry, dict):
                # Already in messages format (newer Gradio versions).
                messages.append({"role": entry["role"], "content": entry["content"]})
            else:
                # Classic (user, assistant) tuple format.
                human, assistant = entry
                messages.append({"role": "user", "content": human})
                messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message})

        # Call the Inference API.
        response = client.chat.completions.create(
            model=MODEL_ID,
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"聊天失败: {str(e)}")
        logger.error(f"详细错误: {traceback.format_exc()}")
        return f"调用失败: {str(e)}\n\n{traceback.format_exc()}"


# Gradio UI definition.
demo = gr.ChatInterface(
    fn=chat_func,
    title="ChatGLM 广告生成助手 (使用 Inference Providers)",
    description="后台使用 Hugging Face Inference Providers,无需本地 GPU",
)

# Serve the Gradio UI at the web root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")


# Health-check endpoint.
@app.get("/health")
async def health():
    """Liveness probe: reports status and the configured model id."""
    return {"status": "ok", "model": MODEL_ID}


if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)