# cafe3310's Hugging Face Space — "Initialize LangGraph chat app" (commit 21916d9)
# NOTE(review): the original lines here were HF file-viewer chrome ("raw / history /
# blame", file size), not Python source; converted to a comment so the file parses.
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, END
# Graph state: a single message list. The `operator.add` reducer means each
# node's returned {"messages": [...]} is APPENDED to the existing list rather
# than replacing it.
class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]
# Load the tokenizer and model exactly once at import time; every request
# served by this process shares them.
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU to halve memory; full float32 on CPU where fp16 is slow.
_dtype = torch.float16 if device == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=_dtype,
    trust_remote_code=True,
).to(device)
# Graph node: generate the assistant's next reply from the message history.
def call_model(state: GraphState):
    """Run the local causal LM over the accumulated conversation.

    Flattens the LangChain messages into a role-tagged plain-text prompt,
    generates a continuation, and returns a partial state update. The single
    AIMessage is appended to ``messages`` by the ``operator.add`` reducer
    declared on GraphState.
    """
    messages = state["messages"]
    # Build the prompt with join instead of repeated += (linear, not quadratic).
    # The resulting string is identical to the original concatenation.
    parts = []
    for msg in messages:
        if msg.type == "system":
            parts.append(f"{msg.content}\n")
        elif msg.type == "human":
            parts.append(f"User: {msg.content}\n")
        elif msg.type == "ai":
            parts.append(f"Assistant: {msg.content}\n")
    prompt = "".join(parts) + "Assistant:"
    # Keep the full encoding so we can pass the attention mask explicitly;
    # generate() warns (and can misbehave with padding) when it has to guess
    # the mask from pad_token_id == eos_token_id.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=512,  # TODO: expose as a configurable parameter
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    output = tokenizer.decode(
        output_ids[0][encoded.input_ids.shape[1]:], skip_special_tokens=True
    )
    return {"messages": [AIMessage(content=output)]}
# Build the graph: a single "llm" node wired straight to END — one model call
# per invocation, no branching or tools yet.
workflow = StateGraph(GraphState)
workflow.add_node("llm", call_model)
workflow.set_entry_point("llm")
workflow.add_edge("llm", END)
# Compile into the runnable used by the Gradio handler.
app = workflow.compile()
@spaces.GPU
def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
    """Gradio handler: convert chat history to LangChain messages, invoke the graph.

    Args:
        message: The user's new message (str).
        history: Prior turns. With ChatInterface(type="messages") each entry is
            a {"role": ..., "content": ...} dict; legacy (user, bot) tuples are
            also accepted for backward compatibility.
        system_message: Optional system prompt from the UI textbox.
        hf_token: OAuth token injected by Gradio; currently unused.

    Returns:
        The assistant's reply text from the final graph state.
    """
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))
    for turn in history:
        if isinstance(turn, dict):
            # type="messages" format — the original tuple unpacking would have
            # unpacked the dict's KEYS here, corrupting the history.
            role, content = turn.get("role"), turn.get("content")
            if not content:
                continue
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
        else:
            # Legacy tuple format: (user_message, bot_message).
            user_message, bot_message = turn
            if user_message:
                messages.append(HumanMessage(content=user_message))
            if bot_message:
                messages.append(AIMessage(content=bot_message))
    messages.append(HumanMessage(content=message))
    # Single-shot invocation; the reducer appends the model's reply last.
    final_state = app.invoke({"messages": messages})
    return final_state["messages"][-1].content
# Chat UI. type="messages" makes Gradio pass history as role/content dicts,
# which maps cleanly onto LangChain message objects in respond().
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # messages format matches the LangChain conversion better
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    ],
)
# Page layout: a sidebar with the HF login button, plus the chat interface.
with gr.Blocks() as demo:
    gr.Markdown("# HuggingFace Running")
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

# Standard script entry guard so importing this module doesn't start a server.
if __name__ == "__main__":
    demo.launch()