import operator
from typing import Annotated, Literal

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing_extensions import TypedDict
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langgraph.graph import StateGraph, END


class GraphState(TypedDict):
    """Graph state: messages accumulate across nodes via operator.add."""
    messages: Annotated[list[AnyMessage], operator.add]


# Load the model and tokenizer exactly once, at module import time.
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
).to(device)


def call_model(state: GraphState):
    """LLM node: render accumulated messages into a prompt and generate a reply.

    Returns a partial state update whose single AIMessage is appended to
    ``messages`` by the ``operator.add`` reducer on GraphState.
    """
    messages = state["messages"]

    # Flatten the conversation into a plain-text prompt.
    prompt = ""
    for msg in messages:
        if msg.type == "system":
            prompt += f"{msg.content}\n"
        elif msg.type == "human":
            prompt += f"User: {msg.content}\n"
        elif msg.type == "ai":
            prompt += f"Assistant: {msg.content}\n"
    prompt += "Assistant:"

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=512,  # hard-coded for now
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    output = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
    )

    return {"messages": [AIMessage(content=output)]}


# Build and compile the single-node graph: llm -> END.
workflow = StateGraph(GraphState)
workflow.add_node("llm", call_model)
workflow.set_entry_point("llm")
workflow.add_edge("llm", END)
app = workflow.compile()


def _history_to_messages(history):
    """Convert Gradio chat history into a list of LangChain messages.

    Handles both history formats:
    - "messages" format (list of {"role", "content"} dicts) — what
      gr.ChatInterface(type="messages") actually passes;
    - legacy tuple format (list of (user, assistant) pairs).

    The original code only unpacked tuples; with type="messages" that
    unpacked a dict's *keys* ("role", "content") instead of its values,
    corrupting the conversation sent to the model.
    """
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if not content:
                continue
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
        else:
            user_message, bot_message = turn
            if user_message:
                messages.append(HumanMessage(content=user_message))
            if bot_message:
                messages.append(AIMessage(content=bot_message))
    return messages


@spaces.GPU
def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
    """Gradio handler: forward the conversation to the LangGraph app.

    Args:
        message: the user's newest utterance.
        history: prior turns from gr.ChatInterface (messages or tuple format).
        system_message: optional system prompt from the additional textbox.
        hf_token: OAuth token injected by Gradio's login flow (unused here).

    Returns:
        The model's reply text (content of the final AIMessage).
    """
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))
    messages.extend(_history_to_messages(history))
    messages.append(HumanMessage(content=message))

    # Single synchronous pass through the graph.
    final_state = app.invoke({"messages": messages})
    return final_state["messages"][-1].content


chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history arrives as role/content dicts (see helper above)
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    ],
)

with gr.Blocks() as demo:
    gr.Markdown("# HuggingFace Running")
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()