Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import (AutoTokenizer, AutoModel, LEDTokenizer, LEDForConditionalGeneration) | |
| from keybert import KeyBERT | |
| import spacy, spacy.cli | |
| import gradio as gr | |
| from sklearn.cluster import AgglomerativeClustering | |
| import re | |
| import pandas as pd | |
| import sys | |
| import os | |
def load_models():
    """Load every model the app uses and return them grouped in a dict.

    Models are moved to CUDA when available, otherwise they stay on CPU.

    Returns:
        dict: ``{"led": {"tokenizer", "model"}, "pubmed": {"tokenizer", "model"},
        "keybert": KeyBERT instance, "spacy": spaCy pipeline "en_core_sci_sm"}``.
    """
    from transformers import pipeline  # only needed to wrap PubMedBERT for KeyBERT

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pubmed_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    pubmed_tok = AutoTokenizer.from_pretrained(pubmed_name)
    pubmed_model = AutoModel.from_pretrained(pubmed_name, use_safetensors=True).to(device)

    # KeyBERT accepts a sentence-transformers model, a model-name string or a
    # Hugging Face pipeline. A bare AutoModel is none of these, so KeyBERT would
    # fall back to its own default embedder and ignore PubMedBERT entirely.
    # Wrapping the encoder in a "feature-extraction" pipeline makes KeyBERT
    # (and group_similar, via kw_model.model.embed) use PubMedBERT embeddings.
    pubmed_pipeline = pipeline(
        "feature-extraction",
        model=pubmed_model,
        tokenizer=pubmed_tok,
        device=0 if device == "cuda" else -1,
    )
    kb_model = KeyBERT(model=pubmed_pipeline)

    led_name = "dancessa/led_pubmed_summarization"
    led_tok = LEDTokenizer.from_pretrained(led_name)
    led_mod = LEDForConditionalGeneration.from_pretrained(led_name, use_safetensors=True).to(device)

    try:
        nlp = spacy.load("en_core_sci_sm")
    except OSError:
        # NOTE(review): en_core_sci_sm is a scispaCy model; it may not be
        # resolvable through spacy.cli.download — confirm, or install it from
        # the scispaCy release URL in the requirements instead.
        spacy.cli.download("en_core_sci_sm")
        nlp = spacy.load("en_core_sci_sm")

    return {
        "led": {
            "tokenizer": led_tok,
            "model": led_mod,
        },
        "pubmed": {
            "tokenizer": pubmed_tok,
            "model": pubmed_model,
        },
        "keybert": kb_model,
        "spacy": nlp,
    }
# Load all models once at import time; the helpers below read these globals.
models = load_models()
nlp, kw_model = models["spacy"], models["keybert"]
pubmed_tokenizer, pubmed_model = models["pubmed"]["tokenizer"], models["pubmed"]["model"]
led_tokenizer, led_model = models["led"]["tokenizer"], models["led"]["model"]
def generate_summary(medical_text):
    """Summarize *medical_text* with the LED model and format the result.

    The input is truncated to 4096 tokens. The raw decoded output is handed to
    ``format_medical_summary``, which extracts the RESULTS / CONCLUSIONS parts.

    Args:
        medical_text: the article text to summarize.

    Returns:
        str: the formatted summary.
    """
    # Use the device the model actually lives on instead of re-detecting it.
    device = led_model.device

    # No fixed-length padding: padding a single input to 4096 tokens only adds
    # work. The attention mask is passed explicitly below.
    inputs = led_tokenizer(
        medical_text,
        max_length=4096,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # LED combines local windowed attention with optional global attention;
    # giving the first token global attention is the documented setup for
    # LED summarization.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    with torch.no_grad():
        outputs = led_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=256,
            num_beams=4,
            early_stopping=True,
            length_penalty=1.2,
            no_repeat_ngram_size=3,
            repetition_penalty=1.5,
        )

    # Special tokens are kept on purpose: format_medical_summary parses the
    # <results>/<conclusions> markers (presumably registered as special tokens
    # in this fine-tuned tokenizer — TODO confirm) and strips <s>/</s> itself.
    generated_summary = led_tokenizer.decode(outputs[0], skip_special_tokens=False)
    return format_medical_summary(generated_summary)
def format_medical_summary(generated_text):
    """Turn raw LED output into a readable RESULTS / CONCLUSIONS summary.

    The text after ``<results>`` (up to ``<conclusions>``) becomes the RESULTS
    section. The text after ``<conclusions>`` (up to ``<dig>``) becomes the
    CONCLUSIONS section. Each found section is capitalised and given a heading.

    If neither section is found, the cleaned text with any remaining
    ``<tag>`` markers removed is returned instead, so the user does not get an
    empty summary.

    Args:
        generated_text: decoded model output, special tokens included.

    Returns:
        str: the formatted summary (may be empty only if the model produced
        no text at all).
    """
    clean_text = (
        generated_text.replace('</s>', '')
        .replace('<s>', '')
        .replace('<pad>', '')
        .strip()
    )

    results_section = ''
    conclusions_section = ''
    if '<results>' in clean_text:
        results_part = clean_text.split('<results>')[1]
        results_section = results_part.split('<conclusions>')[0].strip()
    if '<conclusions>' in clean_text:
        conclusions_part = clean_text.split('<conclusions>')[1]
        conclusions_section = conclusions_part.split('<dig>')[0].strip()

    sections = []
    if results_section:
        sections.append("RESULTS:\n" + results_section[0].upper() + results_section[1:])
    if conclusions_section:
        sections.append(
            "CONCLUSIONS:\n" + conclusions_section[0].upper() + conclusions_section[1:]
        )
    if sections:
        return "\n\n".join(sections)

    # Fallback: the model did not emit the expected markers; return its text
    # without tags rather than an empty string.
    without_tags = re.sub(r"</?[A-Za-z_]+>", " ", clean_text)
    return re.sub(r"\s+", " ", without_tags).strip()
def chunk_text(text, max_tokens=512, stride=128):
    """Split *text* into sentence-aligned chunks of at most ~max_tokens tokens.

    Sentences come from the spaCy pipeline ``nlp``; their lengths are measured
    with ``pubmed_tokenizer``. When a chunk is closed, its trailing sentences
    totalling at most ``stride`` tokens are carried into the next chunk, so
    phrases near a boundary appear in both. A single sentence longer than
    ``max_tokens`` becomes its own (oversized) chunk rather than being split.

    Args:
        text: input text.
        max_tokens: token budget per chunk.
        stride: token budget for the overlap carried between chunks.

    Returns:
        list[str]: chunk texts (sentences joined by single spaces); empty if
        *text* yields no sentences.
    """
    chunks = []
    current = []  # (sentence, token_count) pairs of the chunk being built
    current_length = 0

    for sentence in (sent.text for sent in nlp(text).sents):
        n_tokens = len(pubmed_tokenizer.tokenize(sentence))

        # Only close a chunk that has content; this avoids emitting "" when the
        # very first sentence already exceeds the budget.
        if current and current_length + n_tokens > max_tokens:
            chunks.append(" ".join(s for s, _ in current))

            # Carry over the tail of the finished chunk, measured in tokens
            # (not sentences), up to the `stride` budget.
            overlap = []
            overlap_length = 0
            for sent, count in reversed(current):
                if overlap_length + count > stride:
                    break
                overlap.append((sent, count))
                overlap_length += count
            overlap.reverse()

            # Drop carried sentences that would push the next chunk over budget.
            while overlap and overlap_length + n_tokens > max_tokens:
                overlap_length -= overlap.pop(0)[1]

            current, current_length = overlap, overlap_length

        current.append((sentence, n_tokens))
        current_length += n_tokens

    if current:
        chunks.append(" ".join(s for s, _ in current))
    return chunks
def extract_candidates(text):
    """Build the candidate keyphrase list handed to KeyBERT.

    Three sources are merged as sets (so the result has no duplicates and no
    guaranteed order):
      * spaCy noun chunks of 1-5 tokens, tokens joined by single spaces, lowercased;
      * adjacent (ADJ or NOUN) + NOUN token pairs, lowercased;
      * all-uppercase tokens of 2-6 characters (abbreviation-like), case kept.
    """
    doc = nlp(text)

    chunk_phrases = set()
    for span in doc.noun_chunks:
        if 1 <= len(span) <= 5:
            chunk_phrases.add(" ".join(token.text for token in span).lower())

    pair_phrases = set()
    for left, right in zip(doc, doc[1:]):
        if left.pos_ in {"ADJ", "NOUN"} and right.pos_ == "NOUN":
            pair_phrases.add(f"{left.text} {right.text}".lower())

    abbreviations = set()
    for token in doc:
        if token.is_upper and 2 <= len(token.text) <= 6:
            abbreviations.add(token.text)

    return list(chunk_phrases | pair_phrases | abbreviations)
def extract_keyphrases(text, top_n=30):
    """Return up to *top_n* (phrase, score) pairs for *text* via KeyBERT.

    KeyBERT is restricted to the candidates from extract_candidates and uses
    MMR with diversity 0.85. It is asked for 2 * top_n phrases, and the list
    it returns is then cut to the first top_n entries.
    """
    options = {
        "candidates": extract_candidates(text),
        "keyphrase_ngram_range": (2, 5),
        "nr_candidates": 80,
        "use_mmr": True,
        "diversity": 0.85,
        "top_n": top_n * 2,
    }
    scored = kw_model.extract_keywords(text, **options)
    return scored[:top_n]
def group_similar(keywords, thresh=0.85):
    """Collapse near-duplicate keyphrases, keeping the best-scored one per group.

    Phrases are embedded with the KeyBERT backend (``kw_model.model.embed``)
    and clustered with average-linkage agglomerative clustering on cosine
    distance. Phrases within distance ``1 - thresh`` of each other are merged.

    Args:
        keywords: list of (phrase, score) pairs.
        thresh: cosine-similarity threshold for merging.

    Returns:
        list of (phrase, score), one per cluster, sorted by score descending.
    """
    if len(keywords) < 2:
        # Clustering needs at least two samples, and there is nothing to merge.
        return [(phrase, score) for phrase, score in keywords]

    phrases = [p for p, _ in keywords]
    emb = kw_model.model.embed(phrases)

    params = dict(n_clusters=None, distance_threshold=1 - thresh, linkage="average")
    try:
        # scikit-learn >= 1.2 names the distance `metric`; `affinity` was removed in 1.4.
        clusterer = AgglomerativeClustering(metric="cosine", **params)
    except TypeError:
        # Older scikit-learn only knows the `affinity` keyword.
        clusterer = AgglomerativeClustering(affinity="cosine", **params)
    labels = clusterer.fit_predict(emb)

    best = {}
    for (ph, sc), lb in zip(keywords, labels):
        if lb not in best or sc > best[lb][1]:
            best[lb] = (ph, sc)
    return sorted(best.values(), key=lambda x: x[1], reverse=True)
def extract_keyphrases_from_long_text(text):
    """Extract keyphrases from text of any length.

    The text is split with chunk_text, and keyphrases are extracted per chunk.
    A phrase found in several chunks keeps its highest score. Returns the 30
    best (phrase, score) pairs, highest score first.
    """
    per_chunk = [extract_keyphrases(chunk) for chunk in chunk_text(text)]

    best_scores = {}
    for chunk_keywords in per_chunk:
        for phrase, score in chunk_keywords:
            if phrase in best_scores and not score > best_scores[phrase]:
                continue
            best_scores[phrase] = score

    ranked = sorted(best_scores.items(), key=lambda pair: pair[1], reverse=True)
    return ranked[:30]
def format_keyterms_output(keywords):
    """Render (phrase, score) pairs as a "KEY TERMS:" bullet list; scores are not shown."""
    bullets = "\n".join(f"- {phrase}" for phrase, _score in keywords)
    return "KEY TERMS:\n" + bullets
def extract_references(text):
    """Find reference-like identifiers in *text*.

    Recognised: URLs, arXiv ids, DOIs (with or without a ``doi:`` prefix),
    PMIDs, PMCIDs, ClinicalTrials.gov NCT numbers, ISBNs and e-mail addresses.
    Matching is case-insensitive.

    Patterns are tried in the order listed. A match overlapping text already
    claimed by an earlier match is skipped, so a DOI inside a URL or after a
    ``doi:`` prefix is reported once. Exact repeats are also reported once.
    Trailing sentence punctuation (``.,;:``) is stripped from each match.

    Returns:
        list[str]: references in pattern order, then order of appearance.
    """
    patterns = [
        r'https?://[^\s<>"]+|www\.[^\s<>"]+',  # URL
        r'arxiv:\d{4}\.\d{4,5}',  # arXiv identifier
        r'doi:\s*10\.\d{4,9}/[-._;()/:A-Za-z0-9]+',  # DOI with "doi:" prefix
        r'10\.\d{4,9}/[-._;()/:A-Za-z0-9]+',  # bare DOI
        r'PMID:\s*\d+',  # PubMed ID
        r'PMCID:\s*PMC\d+',  # PubMed Central ID
        r'NCT\d{8}',  # ClinicalTrials.gov registry number
        r'ISBN(?:-13)?:?\s*(?:97[89][- ]?)?\d{1,5}[- ]?\d+[- ]?\d+[- ]?[\dX]',  # ISBN
        r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',  # e-mail address
    ]
    claimed_spans = []
    seen = set()
    results = []
    for pattern in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            start, end = match.span()
            if any(start < c_end and c_start < end for c_start, c_end in claimed_spans):
                continue
            # Claim the span even for a repeat, so later (narrower) patterns
            # do not re-report a fragment of it.
            claimed_spans.append((start, end))
            reference = match.group(0).rstrip('.,;:')
            if reference and reference not in seen:
                seen.add(reference)
                results.append(reference)
    return results
def format_references_output(references):
    """Render references as a "REFERENCES:" bullet list, one item per line."""
    bullets = "\n".join(f"- {item}" for item in references)
    return "REFERENCES:\n" + bullets
def gradio_summarize(medical_text):
    """Gradio callback: build the full report for *medical_text*.

    The report is the LED summary, the key-terms list and, when any are found,
    a references list, separated by blank lines. Inputs shorter than 3000
    characters after stripping surrounding whitespace (including a missing or
    None value) get a prompt message instead.
    """
    medical_text = (medical_text or "").strip()
    if len(medical_text) < 3000:
        return "Пожалуйста, введите медицинский текст на английском языке не менее 3000 символов"

    sections = [
        generate_summary(medical_text),
        format_keyterms_output(extract_keyphrases_from_long_text(medical_text)),
    ]
    references = extract_references(medical_text)
    if references:
        sections.append(format_references_output(references))
    return "\n\n".join(sections)
def main():
    """Build the Gradio interface and launch it with a public share link."""
    with gr.Blocks() as app:
        gr.Markdown("# Автоматическое резюмирование медицинских публикаций")
        gr.Markdown("Введите медицинский текст (не менее 3000 символов). Модель сгенерирует краткое содержание.")

        source_box = gr.Textbox(
            label="Входной текст",
            placeholder="Введите здесь медицинский текст...",
            lines=15,
            elem_id="input-textbox",
        )
        summary_box = gr.Textbox(
            label="Конспект",
            lines=20,
            interactive=False,
            elem_id="output-textbox",
        )

        run_button = gr.Button("Сгенерировать конспект")
        run_button.click(gradio_summarize, inputs=source_box, outputs=summary_box)

    app.launch(share=True)


if __name__ == "__main__":
    main()