TLH01 committed on
Commit
9987eb0
·
verified ·
1 Parent(s): 9b8533f

Rename app.py to t5_small.py

Browse files
Files changed (2) hide show
  1. app.py +0 -0
  2. t5_small.py +109 -0
app.py DELETED
File without changes
t5_small.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import time
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+
6
# Page configuration: browser-tab title, icon, and wide layout.
st.set_page_config(page_title="文本关键点提取工具", page_icon="📝", layout="wide")

# Page heading and a one-line description of the tool.
st.title("文本关键点提取工具")
st.markdown("基于t5-small模型,从文本中提取关键点")

# Model table. main() passes the first key to load_model() and shows the
# first value as the label above the extracted result.
model_list = {"t5-small": "keypoint_T5-Small"}
21
+
22
# Cache the loaded model so it is not loaded again on every rerun.
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and seq2seq model identified by ``model_name``.

    Shows an info message before loading and a success message with the
    elapsed wall-clock time afterwards. The model is moved to CUDA when
    ``torch.cuda.is_available()`` reports a GPU, otherwise to the CPU.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    st.info(f"正在加载模型: {model_name}")
    began_at = time.time()

    tok = AutoTokenizer.from_pretrained(model_name)
    net = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Use the GPU when one is available.
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    net = net.to(target)

    took = time.time() - began_at
    st.success(f"✅ 模型加载完成: {model_name},耗时 {took:.2f} 秒")

    return net, tok, target
39
+
40
# Generate the key point for a piece of text.
def generate_keypoints(model, tokenizer, device, text, max_new_tokens=64):
    """Return the key point the model generates for ``text``.

    Args:
        model: Object whose ``generate`` method produces output token ids.
        tokenizer: Callable that encodes the prompt; also used to decode.
        device: Device the encoded inputs are moved to before generation.
        text: Raw user text.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Returns:
        The decoded, stripped model output. Blank input yields a prompt
        asking for text, and "no key point"-style outputs are normalized to
        a single placeholder message.
    """
    # Guard clause: nothing to do for empty or whitespace-only input.
    if text.strip() == "":
        return "请输入文本内容"

    # Task prefix for the T5 model, then tokenize and move to the device.
    encoded = tokenizer(
        f"summarize: {text}",
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens)

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    # Post-processing: collapse the "no key point" variants into one message.
    empty_markers = ("none", "no keypoint", "no key point", "n/a", "na", "", "nothing")
    return "未提取到关键点" if decoded.lower() in empty_markers else decoded
63
+
64
# Main UI
def main():
    """Render the key-point extraction page.

    Layout: a sidebar slider for the generation length, a left column with
    the input text area and the extract button, and a right column showing
    the last result kept in ``st.session_state``.

    Blank input is rejected with a warning instead of being sent to the
    model, so a previous result is never overwritten by a prompt message and
    no success banner is shown for it.
    """
    # Sidebar
    with st.sidebar:
        st.header("模型配置")
        max_new_tokens = st.slider(
            "最大生成长度", min_value=16, max_value=256, value=64, step=16
        )

    # Load the model named by the first key of model_list.
    model, tokenizer, device = load_model(next(iter(model_list)))

    # Main content area
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("输入文本")
        user_text = st.text_area(
            "请输入需要提取关键点的文本",
            height=300,
            placeholder="在此粘贴文本内容...",
        )

        if st.button("提取关键点"):
            # Explicit None checks: object truthiness is not a reliable
            # "was it loaded" signal for these objects.
            if model is None or tokenizer is None or device is None:
                st.warning("请先确保模型加载成功")
            elif not user_text.strip():
                # Do not report success or store a result for blank input.
                st.warning("请输入文本内容")
            else:
                with st.spinner("正在提取关键点..."):
                    start_time = time.time()
                    result = generate_keypoints(
                        model, tokenizer, device, user_text, max_new_tokens
                    )
                    elapsed = time.time() - start_time

                # Keep the result so the right column can show it.
                st.session_state["result"] = result
                st.session_state["time"] = elapsed

                st.success(f"✅ 关键点提取完成,耗时 {elapsed:.2f} 秒")

    with col2:
        st.subheader("提取结果")
        if "result" in st.session_state:
            # The first value of model_list is used as the result label.
            st.markdown(f"**{next(iter(model_list.values()))}:**")
            st.info(st.session_state["result"])
            st.caption(f"生成耗时: {st.session_state['time']:.2f} 秒")
        else:
            st.info("请输入文本并点击提取按钮")


if __name__ == "__main__":
    main()