xingyu1996 committed
Commit 381a595 · verified · 1 parent: f316b74

Update app.py

Files changed (1)
app.py +69 -69
app.py CHANGED
@@ -1,78 +1,78 @@
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# --- Set the model ID ---
model_id = "xingyu1996/tiger-gpt2"
client = InferenceClient(model_id)

# --- Key change: load the same tokenizer that was used during training ---
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    prompt = message
    response_ids = []
    response_text = ""

    # --- Prepare the generation parameters ---
    generation_args = {
        "max_new_tokens": max_tokens,
        "stream": True,
        "details": True,  # ask the API to return token IDs (key change)
    }

    if temperature is not None and temperature > 0:
        generation_args["temperature"] = temperature
    if top_p is not None and top_p < 1.0:
        generation_args["top_p"] = top_p

    try:
        # --- Call the API and collect token IDs ---
        for output in client.text_generation(prompt, **generation_args):
            if hasattr(output, "token"):  # streaming output
                # output.token carries the sampled token's id and text
                token_id = output.token.id
                response_ids.append(token_id)

                # Decode with our own tokenizer
                current_text = tokenizer.decode(response_ids, skip_special_tokens=True)
                response_text = current_text

                yield response_text
            elif hasattr(output, "generated_text"):  # final output of a non-streaming call
                # If the full text comes back directly, use it as-is
                response_text = output.generated_text
                yield response_text
    except Exception as e:
        print(f"Error during inference: {type(e).__name__} - {e}")
        yield f"Sorry, inference failed: {type(e).__name__} - {str(e)}"


# The rest of the Gradio interface code is unchanged
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title=f"Inference test: {model_id}",
    description="Enter Chinese text and the model will complete it.",
)

if __name__ == "__main__":
    demo.launch()
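For reference, the streaming shape that respond() relies on can be checked outside Gradio. A minimal sketch, assuming the Space's model is reachable through the serverless Inference API: with stream=True and details=True, huggingface_hub yields chunks whose .token field carries the id that respond() accumulates.

from huggingface_hub import InferenceClient

client = InferenceClient("xingyu1996/tiger-gpt2")
for out in client.text_generation("你好", max_new_tokens=8, stream=True, details=True):
    # each streamed chunk exposes the sampled token's id and text
    print(out.token.id, repr(out.token.text))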
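Why does respond() re-decode the whole accumulated ID list on every step instead of concatenating each token's text? GPT-2 uses byte-level BPE, so a single multi-byte character (such as a Chinese character) can be split across several tokens, and decoding tokens one at a time produces replacement characters at the boundaries. A small sketch illustrating this, assuming only the stock "gpt2" tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("你好")  # one Chinese character spans multiple byte-level tokens

# Decoding token by token breaks multi-byte characters at token boundaries:
print([tok.decode([i]) for i in ids])  # pieces containing '�'

# Decoding the running ID list, as respond() does, always yields valid text:
print(tok.decode(ids))  # 你好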