Spaces:
Build error
Build error
hellopahe
commited on
Commit
·
e0738a2
1
Parent(s):
534bdc5
add luotuo summary
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import numpy
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
|
| 5 |
-
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
|
| 6 |
from article_extractor.tokenizers_pegasus import PegasusTokenizer
|
| 7 |
from embed import Embed
|
| 8 |
|
|
@@ -12,6 +12,9 @@ from harvesttext import HarvestText
|
|
| 12 |
from sentence_transformers import SentenceTransformer, util
|
| 13 |
from LexRank import degree_centrality_scores
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class SummaryExtractor(object):
|
| 17 |
def __init__(self):
|
|
@@ -24,6 +27,39 @@ class SummaryExtractor(object):
|
|
| 24 |
print(content)
|
| 25 |
return str(self.text2text_genr(content, min_length=20, do_sample=False, num_return_sequences=3)[0]["generated_text"])
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
class LexRank(object):
|
| 28 |
def __init__(self):
|
| 29 |
self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
|
|
@@ -56,35 +92,28 @@ class LexRank(object):
|
|
| 56 |
|
| 57 |
# ---===--- worker instances ---===---
|
| 58 |
t_randeng = SummaryExtractor()
|
|
|
|
|
|
|
| 59 |
embedder = Embed()
|
| 60 |
lex = LexRank()
|
| 61 |
|
| 62 |
|
| 63 |
def randeng_extract(content):
|
| 64 |
sentences = lex.find_central(content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
ptr = index - 1
|
| 72 |
-
break
|
| 73 |
-
if num < 0 and index == 0:
|
| 74 |
-
ptr = index
|
| 75 |
-
break
|
| 76 |
-
print(">>>")
|
| 77 |
-
for ele in sentences[:ptr]:
|
| 78 |
-
print(ele)
|
| 79 |
-
return t_randeng.extract("".join(sentences[:ptr]))
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def similarity_check(inputs: list):
|
| 83 |
-
doc_list = inputs[1].split("\n")
|
| 84 |
-
doc_list.append(inputs[0])
|
| 85 |
-
embedding_list = embedder.encode(doc_list)
|
| 86 |
-
scores = (embedding_list[-1] @ tf.transpose(embedding_list[:-1]))[0].numpy().tolist()
|
| 87 |
-
return numpy.array2string(scores, separator=',')
|
| 88 |
|
| 89 |
with gr.Blocks() as app:
|
| 90 |
gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
|
|
@@ -92,10 +121,14 @@ with gr.Blocks() as app:
|
|
| 92 |
# text_input = gr.Textbox()
|
| 93 |
# text_output = gr.Textbox()
|
| 94 |
# text_button = gr.Button("生成摘要")
|
| 95 |
-
with gr.Tab("Randeng-Pegasus-523M"):
|
| 96 |
text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
| 97 |
text_output_1 = gr.Textbox(label="摘要文本")
|
| 98 |
text_button_1 = gr.Button("生成摘要")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
with gr.Tab("相似度检���"):
|
| 100 |
with gr.Row():
|
| 101 |
text_input_query = gr.Textbox(label="查询文本")
|
|
@@ -103,7 +136,7 @@ with gr.Blocks() as app:
|
|
| 103 |
text_button_similarity = gr.Button("对比相似度")
|
| 104 |
text_output_similarity = gr.Textbox()
|
| 105 |
|
| 106 |
-
|
| 107 |
text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
|
| 108 |
text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
|
| 109 |
|
|
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
|
| 5 |
+
from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline, AutoModel, AutoTokenizer
|
| 6 |
from article_extractor.tokenizers_pegasus import PegasusTokenizer
|
| 7 |
from embed import Embed
|
| 8 |
|
|
|
|
| 12 |
from sentence_transformers import SentenceTransformer, util
|
| 13 |
from LexRank import degree_centrality_scores
|
| 14 |
|
| 15 |
+
from luotuo_util import DeviceMap
|
| 16 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 17 |
+
|
| 18 |
|
| 19 |
class SummaryExtractor(object):
|
| 20 |
def __init__(self):
|
|
|
|
| 27 |
print(content)
|
| 28 |
return str(self.text2text_genr(content, min_length=20, do_sample=False, num_return_sequences=3)[0]["generated_text"])
|
| 29 |
|
| 30 |
+
class Tuoling_6B_extractor(object):
|
| 31 |
+
def __init__(self):
|
| 32 |
+
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
| 33 |
+
self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
| 34 |
+
self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map=DeviceMap("ChatGLM").get())
|
| 35 |
+
|
| 36 |
+
# load fine-tuned pretrained model.
|
| 37 |
+
peft_path = "./luotuoC.pt"
|
| 38 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1)
|
| 39 |
+
self.model = get_peft_model(self.model, peft_config)
|
| 40 |
+
self.model.load_state_dict(torch.load(peft_path), strict=False)
|
| 41 |
+
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def format_example(example: dict) -> dict:
|
| 45 |
+
context = f"Instruction: {example['instruction']}\n"
|
| 46 |
+
if example.get("input"):
|
| 47 |
+
context += f"Input: {example['input']}\n"
|
| 48 |
+
context += "Answer: "
|
| 49 |
+
target = example["output"]
|
| 50 |
+
return {"context": context, "target": target}
|
| 51 |
+
|
| 52 |
+
def extract(self, instruction: str, input=None) -> str:
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
feature = Tuoling_6B_extractor.format_example(
|
| 55 |
+
{"instruction": "请帮我总结以下内容", "output": "", "input": f"{instruction}"}
|
| 56 |
+
)
|
| 57 |
+
input_text = feature["context"]
|
| 58 |
+
input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
|
| 59 |
+
out = self.model.generate(input_ids=input_ids, max_length=2048, temperature=0)
|
| 60 |
+
answer = self.tokenizer.decode(out[0])
|
| 61 |
+
return answer.split('Answer:')[1]
|
| 62 |
+
|
| 63 |
class LexRank(object):
|
| 64 |
def __init__(self):
|
| 65 |
self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
|
|
|
|
| 92 |
|
| 93 |
# ---===--- worker instances ---===---
|
| 94 |
t_randeng = SummaryExtractor()
|
| 95 |
+
t_tuoling = Tuoling_6B_extractor()
|
| 96 |
+
|
| 97 |
embedder = Embed()
|
| 98 |
lex = LexRank()
|
| 99 |
|
| 100 |
|
| 101 |
def randeng_extract(content):
|
| 102 |
sentences = lex.find_central(content)
|
| 103 |
+
return str(list(t_randeng.extract(sentence) for sentence in sentences))
|
| 104 |
+
|
| 105 |
+
def tuoling_extract(content):
|
| 106 |
+
sentences = lex.find_central(content)
|
| 107 |
+
return str(list(t_tuoling.extract(sentence) for sentence in sentences))
|
| 108 |
+
|
| 109 |
+
def similarity_check(query, doc):
|
| 110 |
+
doc_list = doc.split("\n")
|
| 111 |
|
| 112 |
+
query_embedding = embedder.encode(query)
|
| 113 |
+
doc_embedding = embedder.encode(doc_list)
|
| 114 |
+
scores = (query_embedding @ tf.transpose(doc_embedding))[0].numpy().tolist()
|
| 115 |
+
# scores = list(util.cos_sim(embedding_list[-1], doc_embedding) for doc_embedding in embedding_list[:-1])
|
| 116 |
+
return str(scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
with gr.Blocks() as app:
|
| 119 |
gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
|
|
|
|
| 121 |
# text_input = gr.Textbox()
|
| 122 |
# text_output = gr.Textbox()
|
| 123 |
# text_button = gr.Button("生成摘要")
|
| 124 |
+
with gr.Tab("LexRank->Randeng-Pegasus-523M"):
|
| 125 |
text_input_1 = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
| 126 |
text_output_1 = gr.Textbox(label="摘要文本")
|
| 127 |
text_button_1 = gr.Button("生成摘要")
|
| 128 |
+
with gr.Tab("LexRank->Tuoling-6B-chatGLM"):
|
| 129 |
+
text_input = gr.Textbox(label="请输入长文本:", max_lines=1000)
|
| 130 |
+
text_output = gr.Textbox(label="摘要文本")
|
| 131 |
+
text_button = gr.Button("生成摘要")
|
| 132 |
with gr.Tab("相似度检���"):
|
| 133 |
with gr.Row():
|
| 134 |
text_input_query = gr.Textbox(label="查询文本")
|
|
|
|
| 136 |
text_button_similarity = gr.Button("对比相似度")
|
| 137 |
text_output_similarity = gr.Textbox()
|
| 138 |
|
| 139 |
+
text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
|
| 140 |
text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
|
| 141 |
text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
|
| 142 |
|