tosei0000 committed on
Commit
78f4d36
·
verified ·
1 Parent(s): 9d7d5b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -30
app.py CHANGED
@@ -1,41 +1,86 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
 
 
4
 
5
- # 修改为你上传的模型文件夹路径
6
- model_path = "tosei0000/code-AI"
7
 
8
- # 加载模型和分词器
9
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
10
- model = AutoModelForCausalLM.from_pretrained(
11
- model_path,
12
- trust_remote_code=True,
13
- torch_dtype=torch.float32 # 若使用 GPU,可改为 torch.float16
14
- )
15
  model.eval()
16
 
17
- def chat(prompt, max_new_tokens=512):
18
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with torch.no_grad():
20
  outputs = model.generate(
21
  **inputs,
22
- max_new_tokens=max_new_tokens,
23
- do_sample=True,
24
  top_p=0.95,
25
- temperature=0.8,
26
- pad_token_id=tokenizer.eos_token_id,
27
  )
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
- return response[len(prompt):] # 返回去除原始prompt后的回答部分
30
-
31
- # 创建 Gradio 界面
32
- gr.Interface(
33
- fn=chat,
34
- inputs=[
35
- gr.Textbox(label="你的提问", lines=3, placeholder="请输入 prompt..."),
36
- gr.Slider(128, 1024, step=64, value=512, label="最大生成长度")
37
- ],
38
- outputs="text",
39
- title="🧠 DeepSeek-R1 Chat Demo",
40
- description="使用你本地上传的 DeepSeek-R1 模型运行的聊天机器人。"
41
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
 
5
# Model path (local directory or a Hugging Face Hub repo id).
model_path = "tosei0000/code-AI"  # ← change to your own model directory or Hugging Face repo_id

# Load tokenizer and model.
# trust_remote_code=True allows the repo to ship custom tokenizer/model classes.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# float32 is the safe choice for CPU inference; NOTE(review): on GPU, float16 would
# halve memory — confirm target hardware before changing.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float32)
# Inference-only mode (disables dropout etc.).
model.eval()
12
 
13
# System persona prepended to every prompt; edit to change the assistant's character.
# (The string itself is model input and is kept verbatim.)
SYSTEM_PROMPT = (
    "你是一个善良、聪明、幽默的 AI 编程助手,能写各种 Python/HTML/JavaScript 代码,"
    "也能像朋友一样聊天。请保持耐心、有趣,尽可能详细地回答问题。\n"
)

# Conversation history; only the 5 most recent entries are kept (trimmed in
# generate_reply). NOTE(review): module-level mutable state — it is shared
# across all concurrent Gradio sessions, not per-user.
chat_history = []
21
+
22
# Generate one model reply for a user request.
def generate_reply(user_input, chat_mode, max_tokens=512, temperature=0.7):
    """Produce a model reply for *user_input*.

    Args:
        user_input: Raw text typed by the user.
        chat_mode: "代码生成" for one-shot code generation; anything else
            (the UI sends "聊天") means multi-turn chat with memory.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The assistant's reply text with the prompt stripped off.
    """
    global chat_history

    if chat_mode == "代码生成":
        # One-shot request: do not pollute the chat memory with code prompts
        # (the old version appended them, corrupting later chat context).
        prompt = f"{SYSTEM_PROMPT}\n请根据以下需求生成代码:\n{user_input}\n"
    else:  # chat mode
        chat_history.append(f"用户: {user_input}")
        if len(chat_history) > 5:
            chat_history = chat_history[-5:]  # keep only the 5 most recent entries
        prompt = SYSTEM_PROMPT + "\n" + "\n".join(chat_history) + "\n助手:"

    # Truncate overly long prompts so they fit the model's context window.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad token avoids the "pad_token_id not set" warning for
            # GPT-style models that define no pad token.
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if chat_mode == "代码生成":
        # Bug fix: the code-generation prompt contains no "助手:" marker, so the
        # old split("助手:")[-1] returned the entire decoded text, including the
        # echoed prompt. Strip the prompt prefix instead.
        if decoded_output.startswith(prompt):
            reply = decoded_output[len(prompt):].strip()
        else:
            reply = decoded_output.strip()
    else:
        # Everything after the last "助手:" marker is the new assistant turn.
        reply = decoded_output.split("助手:")[-1].strip()
        # Remember the reply so later turns see it as context.
        chat_history.append(f"助手: {reply}")

    return reply
57
+
58
# Handler for the "reset memory" button.
def reset_memory():
    """Forget all stored conversation turns and confirm to the UI."""
    global chat_history
    confirmation = "记忆已重置。"
    # Rebind (rather than mutate) so any trimming logic elsewhere stays consistent.
    chat_history = []
    return confirmation
63
+
64
# Gradio UI: two-mode (chat / code generation) front end over generate_reply.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 智能代码助理 + 聊天机器人")
    gr.Markdown("支持代码生成与聊天功能,可记忆上下文,具备人格设定!")

    with gr.Row():
        # Mode selector; its string value is matched verbatim inside generate_reply,
        # so the option labels must stay in sync with that function.
        chat_mode = gr.Radio(["聊天", "代码生成"], value="代码生成", label="对话模式")
        reset_btn = gr.Button("🧹 重置记忆")

    user_input = gr.Textbox(label="你的输入", lines=6, placeholder="输入代码需求或聊天内容...")
    # Generation controls forwarded as max_tokens / temperature arguments.
    max_tokens = gr.Slider(50, 1024, value=512, step=10, label="最大生成长度")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="生成多样性(temperature)")

    output = gr.Textbox(label="AI 回复", lines=10)

    submit_btn = gr.Button("✨ 生成")

    # Both buttons write into the same output textbox.
    submit_btn.click(fn=generate_reply, inputs=[user_input, chat_mode, max_tokens, temperature], outputs=output)
    reset_btn.click(fn=reset_memory, outputs=output)

# Launch the server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()