jinv2 commited on
Commit
5d4acdc
·
verified ·
1 Parent(s): 4bdb349

Create app.py

Browse files

添加 OPT-125m Gradio 应用

Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import torch
4
+ import time
5
+
6
+ # --- 配置 ---
7
+ MODEL_ID = "jinv2/opt125m-wikitext2-finetuned"
8
+ TASK = "text-generation"
9
+
10
+ # --- 设备选择 ---
11
+ # 优先使用 GPU (如果 Space 配置了)
12
+ device = 0 if torch.cuda.is_available() else -1
13
+ device_name = "GPU" if device == 0 else "CPU"
14
+ print(f"使用设备: {device_name}")
15
+
16
+ # --- 加载模型 Pipeline ---
17
+ # 使用 pipeline 简化文本生成任务
18
+ print(f"开始加载模型: {MODEL_ID}...")
19
+ try:
20
+ # 对于 OPT 模型,通常不需要 trust_remote_code=True
21
+ # torch_dtype 设为 'auto' 让 transformers 自动选择最佳精度
22
+ pipe = pipeline(
23
+ TASK,
24
+ model=MODEL_ID,
25
+ torch_dtype='auto', # 自动选择精度 (float32 on CPU, float16/bfloat16 on GPU if supported)
26
+ device=device
27
+ )
28
+ print("模型加载成功。")
29
+ # 获取模型实际加载的数据类型
30
+ if hasattr(pipe.model, 'dtype'):
31
+ loaded_dtype = pipe.model.dtype
32
+ print(f"模型加载使用的数据类型: {loaded_dtype}")
33
+ else:
34
+ print("无法自动检测模型加载的数据类型,可能使用默认值。")
35
+
36
+ except Exception as e:
37
+ print(f"加载模型时出错: {e}")
38
+ raise gr.Error(f"加载模型 '{MODEL_ID}' 失败。错误: {e}。请检查 Space 日志。")
39
+
40
+ # --- 文本生成函数 ---
41
+ def generate_text(prompt, max_length, temperature, top_p, repetition_penalty):
42
+ """使用加载的 pipeline 生成文本"""
43
+ if not prompt:
44
+ return "请输入起始文本 (prompt)。"
45
+
46
+ print(f"\n收到提示词: '{prompt}'")
47
+ print(f"生成参数: 最大长度={max_length}, 温度={temperature}, Top-p={top_p}, 重复惩罚={repetition_penalty}")
48
+
49
+ # 注意:max_length 通常包含 prompt 的长度。
50
+ # 我们希望生成 max_new_tokens,所以总长度是 prompt 长度 + max_new_tokens
51
+ # 但 text-generation pipeline 的 max_length 参数是 *总* 长度。
52
+ # 为简单起见,我们直接使用 max_length 作为总长度限制,用户输入的 prompt 会被计算在内。
53
+ # 或者,我们可以计算 prompt 的 token 数量并加上期望的新 token 数。
54
+ # 这里我们采用更简单的 max_length 方法。
55
+
56
+ start_time = time.time()
57
+ try:
58
+ # OPT 模型通常用于文本续写,不需要复杂的聊天模板
59
+ outputs = pipe(
60
+ prompt,
61
+ max_length=max_length, # 这是生成的总文本长度,包括 prompt
62
+ do_sample=True if temperature > 0 else False, # 仅当 temperature > 0 时采样
63
+ temperature=max(temperature, 1e-6), # Temperature 不能为 0 或负数
64
+ top_p=top_p,
65
+ repetition_penalty=repetition_penalty,
66
+ num_return_sequences=1,
67
+ pad_token_id=pipe.tokenizer.eos_token_id # 避免填充警告
68
+ )
69
+ generated_text = outputs[0]['generated_text']
70
+
71
+ # pipeline 输出通常包含原始提示,我们只返回生成的部分
72
+ # (如果需要完整文本,可以直接返回 generated_text)
73
+ response = generated_text[len(prompt):].strip()
74
+
75
+ end_time = time.time()
76
+ duration = end_time - start_time
77
+ print(f"生成完成。原始输出长度: {len(generated_text)}, 提取的续写部分: {response}")
78
+ print(f"生成耗时: {duration:.2f} 秒")
79
+
80
+ # 如果模型有时不生成任何新内容,返回提示信息
81
+ if not response and len(generated_text) <= len(prompt):
82
+ return "(模型没有生成新的文本,可能需要调整参数或 prompt)"
83
+ return response
84
+
85
+ except Exception as e:
86
+ print(f"生成过程中发生错误: {e}")
87
+ import traceback
88
+ traceback.print_exc()
89
+ return f"生成过程中发生错误: {e}"
90
+
91
+ # --- 创建 Gradio 界面 ---
92
+ with gr.Blocks(theme=gr.themes.Soft(), title=f"测试 {MODEL_ID}") as demo:
93
+ gr.Markdown(f"""
94
+ # 测试文本生成模型: `{MODEL_ID}`
95
+ 输入一段起始文本 (prompt),模型将尝试续写它。
96
+ **注意:** 模型运行在 **{device_name}** 上。
97
+ """)
98
+
99
+ with gr.Row():
100
+ with gr.Column(scale=2):
101
+ prompt_input = gr.Textbox(
102
+ label="输入起始文本 (Prompt)",
103
+ lines=5,
104
+ placeholder="例如:从前有一只勇敢的小兔子,它梦想着..."
105
+ )
106
+ with gr.Accordion("高级生成选项", open=False):
107
+ max_length_slider = gr.Slider(
108
+ minimum=20,
109
+ maximum=512, # OPT-125m 的标准上下文长度通常是 2048,但设置低一些以防内存问题和过长生成
110
+ value=100, # 默认生成较短的续写
111
+ step=10,
112
+ label="最大总长度 (Max Length)",
113
+ info="生成的文本(包括提示)的最大令牌数。"
114
+ )
115
+ temperature_slider = gr.Slider(
116
+ minimum=0.1,
117
+ maximum=2.0,
118
+ value=0.7,
119
+ step=0.05,
120
+ label="温度 (Temperature)",
121
+ info="控制随机性。>1 更随机, <1 更确定。0 表示贪婪解码。"
122
+ )
123
+ top_p_slider = gr.Slider(
124
+ minimum=0.1,
125
+ maximum=1.0,
126
+ value=0.9,
127
+ step=0.05,
128
+ label="Top-p (Nucleus Sampling)",
129
+ info="累积概率阈值,用于筛选下一个词的候选。仅在 temperature > 0 时有效。"
130
+ )
131
+ repetition_penalty_slider = gr.Slider(
132
+ minimum=1.0,
133
+ maximum=2.0,
134
+ value=1.1,
135
+ step=0.1,
136
+ label="重复惩罚 (Repetition Penalty)",
137
+ info="大于 1 可减少重复。设为 1.0 则禁用。"
138
+ )
139
+ submit_button = gr.Button("生成续写", variant="primary")
140
+
141
+ with gr.Column(scale=3):
142
+ output_text = gr.Textbox(
143
+ label="模型续写内容 (Generated Text)",
144
+ lines=15,
145
+ interactive=False
146
+ )
147
+
148
+ gr.Examples(
149
+ examples=[
150
+ ["人工智能的未来是", 150, 0.8, 0.9, 1.1],
151
+ ["今天天气真不错,阳光明媚,", 80, 0.7, 0.95, 1.0],
152
+ ["The quick brown fox jumps over the", 50, 0.5, 0.9, 1.2],
153
+ ],
154
+ inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
155
+ outputs=output_text,
156
+ fn=generate_text,
157
+ cache_examples=False,
158
+ label="示例"
159
+ )
160
+
161
+ submit_button.click(
162
+ fn=generate_text,
163
+ inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
164
+ outputs=output_text,
165
+ api_name="generate"
166
+ )
167
+
168
+ # 启动 Gradio 应用
169
+ demo.launch()