import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import operator
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, END
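# Assumed dependencies for this Space (a requirements.txt sketch; pin versions as needed):
#   gradio, spaces, torch, transformers, langchain-core, langgraph, typing-extensions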

# Graph state definition
class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]

# Load the model and tokenizer only once, at import time
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
).to(device)
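# Note: float16 on GPU is assumed here to halve memory; if the model card for
# inclusionAI/Ring-mini-2.0 recommends another dtype (e.g. bfloat16), adjust above.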

# Graph node: single LLM call
@spaces.GPU  # request GPU time on ZeroGPU Spaces; a no-op when running elsewhere
def call_model(state: GraphState):
    """Model-calling node: turns the message history into a prompt and generates a reply."""
    messages = state["messages"]
    # Assemble a plain-text prompt from the accumulated messages
    prompt = ""
    for msg in messages:
        if msg.type == "system":
            prompt += f"{msg.content}\n"
        elif msg.type == "human":
            prompt += f"User: {msg.content}\n"
        elif msg.type == "ai":
            prompt += f"Assistant: {msg.content}\n"
    prompt += "Assistant:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=512,  # hard-coded for now
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    return {"messages": [AIMessage(content=output)]}
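# Note: the prompt is assembled by hand above; many chat models expect
# tokenizer.apply_chat_template(...) instead, which may match Ring-mini-2.0's
# expected chat format more closely; worth checking the model card before relying on this.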

# Build the graph: a single "llm" node from entry point to END
workflow = StateGraph(GraphState)
workflow.add_node("llm", call_model)
workflow.set_entry_point("llm")
workflow.add_edge("llm", END)

# Compile the graph
app = workflow.compile()
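# A minimal sketch of exercising the compiled graph directly (outside Gradio),
# e.g. from a Python shell; the prompt below is illustrative only:
#   state = app.invoke({"messages": [HumanMessage(content="Hello!")]})
#   print(state["messages"][-1].content)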

def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
    """Gradio response function: converts the chat history and calls the LangGraph app."""
    # Convert Gradio's "messages"-style history (role/content dicts) into LangChain messages
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))
    for turn in history:
        role, content = turn["role"], turn["content"]
        if role == "user":
            messages.append(HumanMessage(content=content))
        elif role == "assistant":
            messages.append(AIMessage(content=content))
    messages.append(HumanMessage(content=message))
    # Single-shot call through the compiled graph
    inputs = {"messages": messages}
    final_state = app.invoke(inputs)
    # Extract the last message (the model's reply) from the final state
    final_response = final_state["messages"][-1].content
    return final_response
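# Note: respond returns the full completion in one shot; for token streaming, one
# option (not wired up here) is transformers' TextIteratorStreamer passed to
# model.generate, with this function yielding partial text instead of returning.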

# Chat UI wired to the respond function
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # "messages" history format (role/content dicts) matches the conversion in respond
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    ],
)

with gr.Blocks() as demo:
    gr.Markdown("# HuggingFace Running")
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()