File size: 3,220 Bytes
171f2ef
8fca8cb
d1912a0
 
171f2ef
21916d9
 
 
 
 
 
 
 
 
 
 
 
 
d1912a0
 
 
 
 
 
 
 
 
 
21916d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171f2ef
d1912a0
 
 
21916d9
d1912a0
 
 
 
21916d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171f2ef
21916d9
 
 
 
 
 
 
 
171f2ef
21916d9
171f2ef
 
21916d9
171f2ef
 
 
 
 
 
d1912a0
171f2ef
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict

from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, END


# Graph state: the single structure threaded through all LangGraph nodes.
class GraphState(TypedDict):
    # Annotating with operator.add makes LangGraph *append* each node's
    # returned messages to the running list instead of replacing it.
    messages: Annotated[list[AnyMessage], operator.add]


# Load the model and tokenizer once at import time; they are shared
# (read-only) by every request handled by this process.
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # fp16 only on GPU; CPU inference stays in fp32 for compatibility.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True
).to(device)


# Graph node definition
def call_model(state: GraphState):
    """LangGraph node: run the local HF model over the accumulated messages.

    Flattens the conversation into a role-tagged plain-text prompt,
    generates a completion, and returns it as an AIMessage (the
    ``operator.add`` reducer on GraphState appends it to the history).

    Args:
        state: Current graph state; ``state["messages"]`` holds the
            system/human/ai message history.

    Returns:
        Partial state update: ``{"messages": [AIMessage(...)]}``.
    """
    messages = state["messages"]

    # Assemble the prompt with a single join instead of repeated +=.
    role_prefix = {"system": "", "human": "User: ", "ai": "Assistant: "}
    parts = [
        f"{role_prefix[msg.type]}{msg.content}\n"
        for msg in messages
        if msg.type in role_prefix
    ]
    parts.append("Assistant:")
    prompt = "".join(parts)

    # Pass the attention mask explicitly: calling generate() with only
    # input_ids triggers a transformers warning and can mis-handle
    # padding, since pad_token_id is set to eos_token_id here.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=512,  # TODO: expose as a config value instead of hard-coding
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (slice off the prompt).
    output = tokenizer.decode(
        output_ids[0][encoded.input_ids.shape[1]:], skip_special_tokens=True
    )

    return {"messages": [AIMessage(content=output)]}

# Build the graph: a single LLM node wired entry -> "llm" -> END.
workflow = StateGraph(GraphState)
workflow.add_node("llm", call_model)
workflow.set_entry_point("llm")
workflow.add_edge("llm", END)

# Compile into a runnable application (used by respond() below).
app = workflow.compile()
@spaces.GPU
def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
    """Gradio ChatInterface handler: route the conversation through LangGraph.

    Converts the Gradio chat history into LangChain messages, invokes the
    compiled graph once, and returns the assistant's reply text.

    Args:
        message: The new user message (str).
        history: Prior turns. With ``type="messages"`` on the
            ChatInterface these are ``{"role", "content"}`` dicts; the
            legacy ``(user, bot)`` tuple format is also accepted.
        system_message: Optional system prompt from the extra textbox.
        hf_token: OAuth token injected by Gradio's login flow (unused here).

    Returns:
        The model's reply as a plain string.
    """
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))

    # BUG FIX: ChatInterface is configured with type="messages", so each
    # history entry is a {"role", "content"} dict. The previous tuple
    # unpacking (`user_message, bot_message = turn`) unpacked the dict's
    # *keys*, corrupting the conversation sent to the model.
    for turn in history:
        if isinstance(turn, dict):
            content = turn.get("content")
            if not content:
                continue
            if turn.get("role") == "user":
                messages.append(HumanMessage(content=content))
            elif turn.get("role") == "assistant":
                messages.append(AIMessage(content=content))
        else:
            # Fallback for the legacy (user, bot) tuple history format.
            user_message, bot_message = turn
            if user_message:
                messages.append(HumanMessage(content=user_message))
            if bot_message:
                messages.append(AIMessage(content=bot_message))

    messages.append(HumanMessage(content=message))

    # Single (non-streaming) invocation of the compiled graph.
    final_state = app.invoke({"messages": messages})

    # The operator.add reducer appends, so the reply is the last message.
    return final_state["messages"][-1].content

# Chat UI definition; replies are produced by respond() via the LangGraph app.
chatbot = gr.ChatInterface(
    respond,
    type="messages", # history passed as role/content dicts, matching LangChain messages
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    ],
)

# Page layout: title, HF login button in a sidebar, then the chat widget.
with gr.Blocks() as demo:
    gr.Markdown("# HuggingFace Running")
    with gr.Sidebar():
        # Hugging Face OAuth login; presumably feeds the hf_token
        # parameter of respond() — verify against the Space's OAuth config.
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()