han145 commited on
Commit
4a5c42e
·
verified ·
1 Parent(s): e170451

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -40
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
 
5
  # 全局变量,避免重复加载
6
  model = None
@@ -25,36 +26,13 @@ def load_model():
25
  except Exception as e:
26
  print(f"模型加载失败: {e}")
27
 
28
- def chat_with_deepseek(message, history):
29
- """与DeepSeek模型聊天 - 修正版"""
30
- global model, tokenizer
31
-
32
  if model is None:
33
  load_model()
34
 
35
- # 构建对话历史
36
- conversation = []
37
- for user_msg, assistant_msg in history:
38
- conversation.append({"role": "user", "content": user_msg})
39
- conversation.append({"role": "assistant", "content": assistant_msg})
40
- conversation.append({"role": "user", "content": message})
41
-
42
- # 使用tokenizer的apply_chat_template方法(如果支持)
43
- try:
44
- prompt = tokenizer.apply_chat_template(
45
- conversation,
46
- tokenize=False,
47
- add_generation_prompt=True
48
- )
49
- except:
50
- # 如果不支持apply_chat_template,使用简单格式
51
- prompt = ""
52
- for msg in conversation:
53
- if msg["role"] == "user":
54
- prompt += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
55
- else:
56
- prompt += f"<|im_start|>assistant\n{msg['content']}<|im_end|>\n"
57
- prompt += "<|im_start|>assistant\n"
58
 
59
  # 编码输入
60
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
@@ -75,7 +53,7 @@ def chat_with_deepseek(message, history):
75
  # 解码回复
76
  response = tokenizer.decode(outputs[0], skip_special_tokens=False)
77
 
78
- # 关键修正:提取助理的回复部分
79
  if "<|im_start|>assistant" in response:
80
  # 找到最后一个assistant标记开始的位置
81
  assistant_start = response.rfind("<|im_start|>assistant")
@@ -93,20 +71,51 @@ def chat_with_deepseek(message, history):
93
  # 如果找不到标记,返回整个响应(去除提示部分)
94
  generated_text = response.replace(prompt, "").strip()
95
 
96
- # 关键修改:直接返回字符串,而不是OpenAI格式的字典
97
- return generated_text
98
-
99
- # 预先加载模型(可选,会延长启动时间但减少第一次请求的延迟)
100
- # load_model()
 
 
 
 
101
 
102
  # 创建Gradio界面
103
- demo = gr.ChatInterface(
104
- fn=chat_with_deepseek,
105
- title="DeepSeek-R1 聊天助手",
106
- description="基于DeepSeek-R1-Distill-Qwen-1.5B的聊天机器人",
107
- examples=["你好!", "请介绍一下你自己", "写一个Python函数计算斐波那契数列"],
108
- cache_examples=False # 禁用缓存,避免格式问题
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  if __name__ == "__main__":
112
  demo.launch(
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
+ import json
5
 
6
  # 全局变量,避免重复加载
7
  model = None
 
26
  except Exception as e:
27
  print(f"模型加载失败: {e}")
28
 
29
+ def predict_api(message):
30
+ """API专用预测函数"""
 
 
31
  if model is None:
32
  load_model()
33
 
34
+ # 构建对话提示
35
+ prompt = f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # 编码输入
38
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
 
53
  # 解码回复
54
  response = tokenizer.decode(outputs[0], skip_special_tokens=False)
55
 
56
+ # 提取助理的回复部分
57
  if "<|im_start|>assistant" in response:
58
  # 找到最后一个assistant标记开始的位置
59
  assistant_start = response.rfind("<|im_start|>assistant")
 
71
  # 如果找不到标记,返回整个响应(去除提示部分)
72
  generated_text = response.replace(prompt, "").strip()
73
 
74
+ # 返回OpenAI兼容格式
75
+ return {
76
+ "choices": [{
77
+ "message": {
78
+ "role": "assistant",
79
+ "content": generated_text
80
+ }
81
+ }]
82
+ }
83
 
84
  # 创建Gradio界面
85
+ with gr.Blocks() as demo:
86
+ gr.Markdown("# DeepSeek-R1 API 服务")
87
+
88
+ # 聊天界面
89
+ chatbot = gr.Chatbot(label="DeepSeek-R1")
90
+ msg = gr.Textbox(label="输入消息")
91
+ clear = gr.Button("清除")
92
+
93
+ def respond(message, chat_history):
94
+ """处理聊天请求"""
95
+ # 调用预测函数
96
+ response = predict_api(message)
97
+ # 提取内容
98
+ bot_message = response["choices"][0]["message"]["content"]
99
+ # 更新聊天历史
100
+ chat_history.append((message, bot_message))
101
+ return "", chat_history
102
+
103
+ # 设置界面交互
104
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
105
+ clear.click(lambda: None, None, chatbot, queue=False)
106
+
107
+ # 添加API端点
108
+ gr.Interface(
109
+ fn=predict_api,
110
+ inputs=gr.Textbox(label="输入消息", lines=2),
111
+ outputs=gr.JSON(label="API响应"),
112
+ title="OpenAI兼容API",
113
+ description="使用此端点进行API调用",
114
+ api_name="predict"
115
+ )
116
+
117
+ # 预先加载模型(可选)
118
+ # load_model()
119
 
120
  if __name__ == "__main__":
121
  demo.launch(