Spaces:
Build error
Build error
| import torch | |
| import csv, os, sys | |
| import argparse | |
| from keybert import KeyBERT | |
| from sentence_transformers import SentenceTransformer | |
class KeyWordExtractor():
    """Extract the highest-scoring keywords of a text with a KeyBERT model.

    The constructor loads the sentence-transformer checkpoint named by
    ``KWE_PRETRAINED``, moves it to CUDA when ``torch.cuda.is_available()``
    (CPU otherwise) and wraps it in a ``KeyBERT`` instance stored as
    ``self.kw_model``.

    Attributes:
        SEQ_LENGTH: Set to 512; not read by any method of this class.
        MAX_KW_NGS: Default upper bound of the keyphrase n-gram range.
        NKW: Default number of keywords to return.
        device: The torch device the sentence model was moved to.
        kw_model: The KeyBERT instance used for extraction.
    """

    def __init__(self):
        """Load the pretrained sentence model and build the KeyBERT wrapper."""
        KWE_PRETRAINED = 'medmediani/Arabic-KW-Mdel'
        self.SEQ_LENGTH = 512
        self.MAX_KW_NGS = 3
        self.NKW = 3
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sentence_model = SentenceTransformer(KWE_PRETRAINED)
        sentence_model.to(self.device)
        self.kw_model = KeyBERT(model=sentence_model)

    def _extract_by_paragraph(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return up to ``nkws`` distinct keywords, best score first.

        ``ctxt`` is split on newlines; every non-blank, stripped paragraph is
        passed to ``self.kw_model.extract_keywords`` separately and the
        resulting ``(keyword, score)`` pairs are pooled, sorted by descending
        score and de-duplicated.

        Args:
            ctxt: The text to process.
            nkws: Maximum number of keywords to return (also used as the
                per-paragraph ``top_n``). Defaults to ``self.NKW``.
            max_kw_ngs: Upper bound of the keyphrase n-gram range.
                Defaults to ``self.MAX_KW_NGS``.

        Returns:
            A list of unique keyword strings ordered by descending score.
        """
        # Resolve defaults here as well, so a direct call with the declared
        # ``None`` defaults does not fail on the ``len(...) >= nkws`` check.
        nkws = self.NKW if nkws is None else nkws
        max_kw_ngs = self.MAX_KW_NGS if max_kw_ngs is None else max_kw_ngs

        kws = []
        for paragraph in map(str.strip, ctxt.split("\n")):
            if not paragraph:
                continue
            # NOTE(review): KeyBERT documents `diversity` as applying only
            # together with `use_mmr=True`, which is not enabled here -
            # confirm whether MMR was intended.
            kws.extend(self.kw_model.extract_keywords(
                paragraph,
                keyphrase_ngram_range=(1, max_kw_ngs),
                top_n=nkws,
                diversity=0.8,
                stop_words=None,
            ))
        print("KWS=", kws, file=sys.stderr)
        kws.sort(key=lambda x: x[1], reverse=True)

        # A dict keeps insertion order, so duplicates are dropped while the
        # score ranking survives (a set would lose the ordering).
        ukws = {}
        for kw, _ in kws:
            if len(ukws) >= nkws:
                break
            ukws.setdefault(kw, None)
        return list(ukws)

    def extract(self, ctxt, nkws=None, max_kw_ngs=None):
        """Return the top keywords of ``ctxt`` as one ``", "``-joined string.

        Args:
            ctxt: The text to extract keywords from.
            nkws: Maximum number of keywords; defaults to ``self.NKW``.
            max_kw_ngs: Upper bound of the keyphrase n-gram range; defaults
                to ``self.MAX_KW_NGS``.

        Returns:
            The keywords joined with ``", "``, best score first; an empty
            string when no keyword was found.
        """
        nkws = nkws if nkws is not None else self.NKW
        max_kw_ngs = max_kw_ngs if max_kw_ngs is not None else self.MAX_KW_NGS
        # Extraction is done paragraph by paragraph (the original comment
        # gives the 512-token limit as the reason).
        kw = self._extract_by_paragraph(ctxt, nkws, max_kw_ngs)
        return ", ".join(kw)