badanwang committed on
Commit
c4614a3
·
verified ·
1 Parent(s): 9eb5e7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -151
app.py CHANGED
@@ -1,165 +1,102 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
- from threading import Thread
5
  import os
 
6
 
7
  # --- 配置 ---
8
- MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
9
- print(f"INFO: Application startup. Loading model: {MODEL_ID}")
 
 
10
 
11
- # --- 1. 模型加载 (内置健壮的错误处理) ---
12
- try:
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
14
- model = AutoModelForCausalLM.from_pretrained(
15
- MODEL_ID,
16
- torch_dtype="auto",
17
- device_map="auto",
18
- trust_remote_code=True
19
- )
20
- print("INFO: Model and tokenizer loaded successfully!")
21
- model_loaded = True
22
- except Exception as e:
23
- print(f"FATAL: Failed to load model or tokenizer: {e}")
24
- model_loaded = False
25
- model_load_error = e
26
-
27
- # --- 2. 核心流式推理函数 ---
28
- def stream_predict(prompt: str, history: list[list[str]]):
29
  """
30
- 一个生成器函数,用于流式生成对话。
31
- 它会逐步 (yield) 返回完整的对话历史。
 
 
32
  """
33
- if not model_loaded:
34
- # 如果模型加载失败,则立即抛出错误
35
- raise gr.Error(f"Model is not loaded. Please check logs. Error: {model_load_error}")
36
 
37
- print(f"INFO: Received prompt: '{prompt}'")
38
-
39
- # 将历史记录和新提示转换为模型需要的格式
 
 
 
 
40
  messages = []
41
- for user_msg, assistant_msg in history:
 
42
  messages.append({"role": "user", "content": user_msg})
43
  messages.append({"role": "assistant", "content": assistant_msg})
44
- messages.append({"role": "user", "content": prompt})
45
-
46
- # 应用聊天模板
47
- try:
48
- input_ids = tokenizer.apply_chat_template(
49
- messages,
50
- add_generation_prompt=True,
51
- tokenize=True,
52
- return_tensors="pt"
53
- ).to(model.device)
54
- except Exception as e:
55
- raise gr.Error(f"Error applying chat template: {e}")
56
-
57
- # 初始化 streamer 和生成线程
58
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
59
-
60
- generation_kwargs = dict(
61
- input_ids=input_ids,
62
- streamer=streamer,
63
- max_new_tokens=1024,
64
- do_sample=True,
65
- temperature=0.7,
66
- top_p=0.9
67
- )
68
 
69
- # 在独立线程中运行生成,防止阻塞UI
70
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
- thread.start()
72
-
73
- # 流式输出
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
- # 初始化一个空的字符串来存放助手的回复
76
- assistant_response = ""
77
- # 每次从streamer中获取一个新的文本片段
78
- for new_text in streamer:
79
- if not new_text:
80
- continue
81
- assistant_response += new_text
82
- # 将当前用户输入和不断增长的助手回复组合成新的对话历史
83
- # 然后使用 yield 返回,Gradio会用它来更新UI
84
- yield history + [[prompt, assistant_response]]
85
-
86
- print("INFO: Streaming finished.")
87
-
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
- print(f"ERROR: An error occurred during streaming: {e}")
90
- raise gr.Error(f"An error occurred during generation: {e}")
91
- finally:
92
- # 确保线程结束
93
- thread.join()
94
-
95
-
96
- # --- 3. Gradio Blocks 界面布局 ---
97
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), css="footer {visibility: hidden}") as demo:
98
- gr.Markdown(f"# 流式对话机器人\n### 模型: `{MODEL_ID}`")
99
-
100
- # 使用 gr.State 来存储对话历史
101
- # 这是实现多轮对话的关键
102
- chatbot_state = gr.State([])
103
-
104
- # Chatbot 组件用于显示对话
105
- chatbot_ui = gr.Chatbot(label="对话窗口", height=600)
106
-
107
- with gr.Row():
108
- # Textbox 用于用户输入
109
- prompt_input = gr.Textbox(
110
- show_label=False,
111
- placeholder="请在这里输入您的问题...",
112
- scale=4,
113
- )
114
- # Button 用于提交
115
- submit_button = gr.Button("发送", variant="primary", scale=1)
116
-
117
- # 清除按钮
118
- clear_button = gr.Button("清除对话历史")
119
-
120
- # --- 4. 事件处理逻辑 ---
121
-
122
- # 提交逻辑:
123
- # 1. 点击"发送"按钮或在输入框按回车时触发
124
- # 2. 调用 stream_predict 函数
125
- # 3. 输入是用户输入框(prompt_input)和对话历史状态(chatbot_state)
126
- # 4. 输出会实时更新聊天机器人界面(chatbot_ui)
127
- # 5. 在函数开始前,将用户输入添加到聊天记录的末尾,并清空输入框
128
- def on_submit(prompt, history):
129
- # 将用户输入加入历史,形成 "用户: XXX" 的临时记录
130
- return "", history + [[prompt, None]]
131
-
132
- prompt_input.submit(
133
- on_submit,
134
- [prompt_input, chatbot_state],
135
- [prompt_input, chatbot_ui]
136
- ).then(
137
- stream_predict,
138
- [prompt_input, chatbot_state],
139
- chatbot_ui
140
- )
141
-
142
- submit_button.click(
143
- on_submit,
144
- [prompt_input, chatbot_state],
145
- [prompt_input, chatbot_ui]
146
- ).then(
147
- stream_predict,
148
- [prompt_input, chatbot_state],
149
- chatbot_ui
150
- )
151
-
152
- # 清除逻辑:
153
- # 点击按钮时,清空状态和UI
154
- def on_clear():
155
- return []
156
-
157
- clear_button.click(on_clear, [], chatbot_state)
158
- clear_button.click(on_clear, [], chatbot_ui)
159
-
160
-
161
- # --- 5. 启动应用 ---
162
- print("INFO: Preparing to launch Gradio app...")
163
- # .queue() 启用请求队列,对于流式应用是必需的
164
- # 在Hugging Face Spaces上, 无需 share=True, Gradio会自动处理
165
- demo.queue().launch()
 
1
  import gradio as gr
2
+ import requests
 
 
3
  import os
4
+ import json
5
 
6
  # --- 配置 ---
7
# The API token is read from the Hugging Face Space secrets. A secret named
# "HF_TOKEN" must be added in the Space settings, otherwise this is None.
HF_TOKEN = os.environ.get("HF_TOKEN")
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "badanwang/teacher_basic_qwen3-0.6b"
)
11
 
12
# --- Core chat function ---
def _history_to_messages(history):
    """Convert Gradio chat history into a list of role/content dicts.

    :param history: previous turns, either ``[user_msg, assistant_msg]`` pairs
        or dicts that already carry ``role`` and ``content`` keys
    :return: list of ``{"role": ..., "content": ...}`` dicts in chat order

    Messages that are ``None`` (e.g. a turn that has no reply yet) are
    skipped so that no ``null`` content is sent to the API.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Entry is already a role/content dict; forward it when usable.
            role, content = turn.get("role"), turn.get("content")
            if role and isinstance(content, str):
                messages.append({"role": role, "content": content})
            continue
        user_msg, assistant_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _extract_token_text(raw_line):
    """Return the token text carried by one streamed line, or ``""``.

    :param raw_line: one line (bytes) from ``response.iter_lines()``
    :return: the ``token.text`` value of a ``data:`` event; an empty string
        for blank lines, non-``data:`` lines, unparsable JSON, events
        without a token object, and tokens flagged as ``special``
    """
    if not raw_line:
        return ""
    decoded_line = raw_line.decode("utf-8")
    if not decoded_line.startswith("data:"):
        return ""
    try:
        event = json.loads(decoded_line[len("data:"):])
    except json.JSONDecodeError:
        # Lines that are not valid JSON are ignored.
        return ""
    if not isinstance(event, dict):
        return ""
    token = event.get("token")
    # Tokens flagged "special" are control markers rather than reply text,
    # so they are kept out of the visible answer.
    if not isinstance(token, dict) or token.get("special"):
        return ""
    text = token.get("text", "")
    return text if isinstance(text, str) else ""


def predict(message, history):
    """
    Stream a chat reply from the Hugging Face Inference API.

    :param message: the message the user just sent (str)
    :param history: previous turns, e.g. [[user_msg, assistant_msg], ...]
    :return: a generator that yields the accumulated reply text each time a
        new token arrives; on a request or unexpected error it yields a
        single apology message that includes the error
    :raises gr.Error: if HF_TOKEN is not configured
    """
    if not HF_TOKEN:
        raise gr.Error("Hugging Face API Token 未配置!请在Space的Secrets中添加 HF_TOKEN。")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json"
    }

    # 1. Format the conversation: earlier turns first, then the new message.
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    # 2. Build the request body with streaming enabled.
    # NOTE(review): "inputs" is sent as a list of role/content dicts, exactly
    # as before; confirm the endpoint behind API_URL accepts this shape
    # rather than a single prompt string.
    payload = {
        "inputs": messages,
        "parameters": {
            "max_new_tokens": 2048,  # adjust as needed
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.1,
            "return_full_text": False,
        },
        "stream": True
    }

    # 3. Send the streaming request and relay tokens as they arrive.
    full_response = ""
    try:
        with requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=120) as response:
            # Fail early on HTTP error status codes.
            response.raise_for_status()

            # Read the streamed response line by line.
            for line in response.iter_lines():
                token_text = _extract_token_text(line)
                if token_text:
                    full_response += token_text
                    yield full_response

    except requests.exceptions.RequestException as e:
        print(f"API请求错误: {e}")
        yield f"抱歉,与模型API通信时发生错误: {e}"
    except Exception as e:
        print(f"发生未知错误: {e}")
        yield f"抱歉,发生了一个未知错误: {e}"
85
+
86
# --- Build and launch the Gradio interface ---

# gr.ChatInterface provides the complete chat UI and calls `predict` for each
# user message. No explicit streaming option is passed: per the original
# author's note for Gradio 4.44.1, ChatInterface handles streaming on its own
# as long as the function is a generator.
demo = gr.ChatInterface(
    fn=predict,
    examples=[[prompt] for prompt in ("你好", "请用python写一个快速排序算法", "给我讲个笑话吧")],
    cache_examples=False,
    title="小Q老师 - 基础问答",
    description="与 badanwang/teacher_basic_qwen3-0.6b 模型进行流式对话。直接输入问题开始。",
)

if __name__ == "__main__":
    # The plain launch() is what runs on Hugging Face Spaces. For a local run
    # that needs a shareable link, call demo.launch(share=True) instead.
    demo.launch()