Spaces:
Sleeping
Sleeping
File size: 3,220 Bytes
171f2ef 8fca8cb d1912a0 171f2ef 21916d9 d1912a0 21916d9 171f2ef d1912a0 21916d9 d1912a0 21916d9 171f2ef 21916d9 171f2ef 21916d9 171f2ef 21916d9 171f2ef d1912a0 171f2ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, END
# State schema shared by every node in the LangGraph workflow.
class GraphState(TypedDict):
    """Graph state: a message list whose updates are merged by list
    concatenation (``operator.add``), so each node's returned messages are
    appended rather than replacing the history."""

    messages: Annotated[list[AnyMessage], operator.add]
# Load the tokenizer and model exactly once, at module import time.
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Half precision on GPU; full precision on CPU, where fp16 kernels are
# slow or unavailable for many ops.
_dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=_dtype,
    trust_remote_code=True,
).to(device)
# Graph node: generate the assistant's next reply.
def call_model(state: GraphState):
    """LangGraph node that runs the local model over the conversation state.

    Flattens the accumulated LangChain messages into a role-prefixed text
    prompt, generates a completion, and returns it as an ``AIMessage``
    (appended to ``state["messages"]`` by the ``operator.add`` reducer).

    Args:
        state: Current graph state; ``state["messages"]`` holds the history.

    Returns:
        A partial state update: ``{"messages": [AIMessage(...)]}``.
    """
    messages = state["messages"]

    # Build the prompt with join() rather than repeated += (avoids
    # quadratic string concatenation as the history grows).
    parts = []
    for msg in messages:
        if msg.type == "system":
            parts.append(f"{msg.content}\n")
        elif msg.type == "human":
            parts.append(f"User: {msg.content}\n")
        elif msg.type == "ai":
            parts.append(f"Assistant: {msg.content}\n")
    parts.append("Assistant:")
    prompt = "".join(parts)

    # Fix: forward the full encoding (input_ids AND attention_mask);
    # calling generate() with input_ids alone makes transformers infer the
    # mask and emit warnings.
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        **encoding,
        max_new_tokens=512,  # TODO: surface as a UI control instead of hard-coding
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the prompt echo.
    prompt_len = encoding["input_ids"].shape[1]
    output = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    return {"messages": [AIMessage(content=output)]}
# Build the graph: a single LLM node wired directly to END —
# one model call per invocation, no tools or branching.
workflow = StateGraph(GraphState)
workflow.add_node("llm", call_model)
workflow.set_entry_point("llm")
workflow.add_edge("llm", END)
# Compile the graph into a runnable application.
app = workflow.compile()
@spaces.GPU
def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
    """Gradio handler: convert chat history to LangChain messages and run the graph.

    Args:
        message: The latest user message (str).
        history: Prior turns. With ``gr.ChatInterface(type="messages")`` each
            entry is a dict like ``{"role": "user", "content": "..."}``; the
            legacy ``(user, bot)`` tuple format is also accepted.
        system_message: Optional system prompt text prepended to the history.
        hf_token: OAuth token injected by Gradio's login flow (unused here).

    Returns:
        The assistant's reply text.
    """
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))

    for turn in history:
        # Bug fix: the ChatInterface uses type="messages", so each history
        # entry is a {"role", "content"} dict. The previous tuple unpacking
        # iterated the dict's KEYS, silently feeding the literal strings
        # "role"/"content" into the conversation.
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if not content:
                continue
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
        else:
            # Legacy tuple-style history: (user_message, bot_message).
            user_message, bot_message = turn
            if user_message:
                messages.append(HumanMessage(content=user_message))
            if bot_message:
                messages.append(AIMessage(content=bot_message))

    messages.append(HumanMessage(content=message))

    # Single-shot invocation of the compiled graph.
    final_state = app.invoke({"messages": messages})
    # The operator.add reducer appends; the last message is the model's reply.
    return final_state["messages"][-1].content
# Chat UI definition; rendered inside the Blocks layout below.
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # role/content dict history — maps cleanly onto LangChain messages
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    ],
)
# Page layout: title, sidebar with a Hugging Face login button, and the
# chat interface rendered in the main column.
with gr.Blocks() as demo:
    gr.Markdown("# HuggingFace Running")
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()
|