Spaces:
Runtime error
updates
Browse files- lrt/clustering/clusters.py +30 -1
- lrt/lrt.py +26 -15
lrt/clustering/clusters.py
CHANGED
|
@@ -1,6 +1,32 @@
|
|
| 1 |
from typing import List, Iterable, Union
|
| 2 |
from pprint import pprint
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
class SingleCluster:
|
| 5 |
def __init__(self):
|
| 6 |
self.__container__ = []
|
|
@@ -12,7 +38,10 @@ class SingleCluster:
|
|
| 12 |
def elements(self) -> List:
|
| 13 |
return self.__container__
|
| 14 |
def get_keyphrases(self):
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
def add_keyphrase(self, keyphrase:Union[str,Iterable]):
|
| 17 |
if isinstance(keyphrase,str):
|
| 18 |
if keyphrase not in self.__keyphrases__.keys():
|
|
|
|
| 1 |
from typing import List, Iterable, Union
|
| 2 |
from pprint import pprint
|
| 3 |
|
| 4 |
+
class KeyphraseCount:
    """A keyphrase string paired with the number of times it was observed."""

    def __init__(self, keyphrase: str, count: int) -> None:
        # keyphrase: the phrase text; count: its occurrence count
        self.keyphrase = keyphrase
        self.count = count

    @classmethod
    def reduce(cls, kcs: list):
        """Merge several KeyphraseCount objects into one.

        The merged keyphrase joins the individual phrases with '/' and the
        merged count is the sum of the individual counts.

        :param kcs: List[KeyphraseCount]; may be empty, in which case a
            KeyphraseCount('', 0) is returned (the previous implementation
            raised IndexError on empty input).
        :return: a single merged KeyphraseCount
        """
        # str.join + sum replace the manual loop and avoid the kcs[-1]
        # IndexError on an empty list.
        merged_phrase = '/'.join(kc.keyphrase for kc in kcs)
        merged_count = sum(kc.count for kc in kcs)
        # Use cls(...) rather than the hard-coded class name so subclasses
        # reduce to their own type.
        return cls(merged_phrase, merged_count)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
class SingleCluster:
|
| 31 |
def __init__(self):
|
| 32 |
self.__container__ = []
|
|
|
|
| 38 |
def elements(self) -> List:
    """Return the cluster's internal element list (the list itself, not a copy)."""
    return self.__container__
|
| 40 |
def get_keyphrases(self):
    """Return the stored keyphrases as a list of KeyphraseCount objects.

    Each (phrase, count) pair in the internal __keyphrases__ mapping is
    wrapped in a KeyphraseCount.
    """
    return [KeyphraseCount(phrase, occurrences)
            for phrase, occurrences in self.__keyphrases__.items()]
|
| 45 |
def add_keyphrase(self, keyphrase:Union[str,Iterable]):
|
| 46 |
if isinstance(keyphrase,str):
|
| 47 |
if keyphrase not in self.__keyphrases__.keys():
|
lrt/lrt.py
CHANGED
|
@@ -5,6 +5,8 @@ from .utils import UnionFind, ArticleList
|
|
| 5 |
from .academic_query import AcademicQuery
|
| 6 |
import streamlit as st
|
| 7 |
from tokenizers import Tokenizer
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class LiteratureResearchTool:
|
|
@@ -13,31 +15,40 @@ class LiteratureResearchTool:
|
|
| 13 |
self.cluster_pipeline = ClusterPipeline(cluster_config)
|
| 14 |
|
| 15 |
|
| 16 |
-
def __postprocess_clusters__(self, clusters: ClusterList) ->ClusterList:
|
| 17 |
'''
|
| 18 |
add top-5 keyphrases to each cluster
|
| 19 |
:param clusters:
|
| 20 |
:return: clusters
|
| 21 |
'''
|
| 22 |
-
def condition(x, y):
|
| 23 |
-
return td.ratcliff_obershelp(x, y) > 0.8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
def valid_keyphrase(x:str):
|
| 26 |
-
return x is not None and x != '' and not x.isspace()
|
| 27 |
|
| 28 |
for cluster in clusters:
|
| 29 |
-
|
| 30 |
-
keyphrases = cluster.get_keyphrases()
|
| 31 |
-
keyphrases = list(keyphrases.keys())
|
| 32 |
keyphrases = list(filter(valid_keyphrase,keyphrases))
|
| 33 |
unionfind = UnionFind(keyphrases, condition)
|
| 34 |
unionfind.union_step()
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
|
| 42 |
return clusters
|
| 43 |
|
|
@@ -85,7 +96,7 @@ class LiteratureResearchTool:
|
|
| 85 |
self.literature_search.ieee(query, start_year, end_year, num_papers)) # ArticleList
|
| 86 |
abstracts = articles.getAbstracts() # List[str]
|
| 87 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 88 |
-
clusters = self.__postprocess_clusters__(clusters)
|
| 89 |
return clusters, articles
|
| 90 |
|
| 91 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
|
@@ -97,7 +108,7 @@ class LiteratureResearchTool:
|
|
| 97 |
self.literature_search.arxiv(query, num_papers)) # ArticleList
|
| 98 |
abstracts = articles.getAbstracts() # List[str]
|
| 99 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 100 |
-
clusters = self.__postprocess_clusters__(clusters)
|
| 101 |
return clusters, articles
|
| 102 |
|
| 103 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
|
@@ -109,7 +120,7 @@ class LiteratureResearchTool:
|
|
| 109 |
self.literature_search.paper_with_code(query, num_papers)) # ArticleList
|
| 110 |
abstracts = articles.getAbstracts() # List[str]
|
| 111 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 112 |
-
clusters = self.__postprocess_clusters__(clusters)
|
| 113 |
return clusters, articles
|
| 114 |
|
| 115 |
if platforn_name == 'IEEE':
|
|
|
|
| 5 |
from .academic_query import AcademicQuery
|
| 6 |
import streamlit as st
|
| 7 |
from tokenizers import Tokenizer
|
| 8 |
+
from .clustering.clusters import KeyphraseCount
|
| 9 |
+
|
| 10 |
|
| 11 |
|
| 12 |
class LiteratureResearchTool:
|
|
|
|
| 15 |
self.cluster_pipeline = ClusterPipeline(cluster_config)
|
| 16 |
|
| 17 |
|
| 18 |
+
def __postprocess_clusters__(self, clusters: ClusterList, query: str) -> ClusterList:
    '''
    Attach the top-5 merged keyphrases to each cluster.

    Keyphrases that are near-duplicates (Ratcliff-Obershelp similarity > 0.8)
    are unioned together; each union is merged into a single KeyphraseCount,
    the merged phrases are ranked by total count, and the five highest are
    stored on cluster.top_5_keyphrases.

    :param clusters: clusters to annotate
    :param query: the original search query; keyphrases equal to it are dropped
    :return: the same ClusterList, annotated in place
    '''
    def condition(x: KeyphraseCount, y: KeyphraseCount):
        # Two keyphrases belong in the same union when their strings are
        # highly similar.
        return td.ratcliff_obershelp(x.keyphrase, y.keyphrase) > 0.8

    def valid_keyphrase(x: KeyphraseCount):
        # Drop empty/whitespace/single-character phrases and ones that
        # merely echo the search query.
        tmp = x.keyphrase
        return (tmp is not None and tmp != '' and not tmp.isspace()
                and len(tmp) != 1 and tmp != query)

    for cluster in clusters:
        keyphrases = cluster.get_keyphrases()  # List[KeyphraseCount]
        keyphrases = list(filter(valid_keyphrase, keyphrases))
        unionfind = UnionFind(keyphrases, condition)
        unionfind.union_step()

        tmp = unionfind.get_unions()  # dict(root_id -> [KeyphraseCount])
        tmp = tmp.values()            # [[KeyphraseCount]]
        # Merge each union into one KeyphraseCount, then keep the 5 with the
        # highest total count.
        tmp = [KeyphraseCount.reduce(x) for x in tmp]
        keyphrases = sorted(tmp, key=lambda x: x.count, reverse=True)[:5]
        cluster.top_5_keyphrases = [x.keyphrase for x in keyphrases]

    return clusters
|
| 54 |
|
|
|
|
| 96 |
self.literature_search.ieee(query, start_year, end_year, num_papers)) # ArticleList
|
| 97 |
abstracts = articles.getAbstracts() # List[str]
|
| 98 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 99 |
+
clusters = self.__postprocess_clusters__(clusters,query)
|
| 100 |
return clusters, articles
|
| 101 |
|
| 102 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
|
|
|
| 108 |
self.literature_search.arxiv(query, num_papers)) # ArticleList
|
| 109 |
abstracts = articles.getAbstracts() # List[str]
|
| 110 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 111 |
+
clusters = self.__postprocess_clusters__(clusters,query)
|
| 112 |
return clusters, articles
|
| 113 |
|
| 114 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
|
|
|
| 120 |
self.literature_search.paper_with_code(query, num_papers)) # ArticleList
|
| 121 |
abstracts = articles.getAbstracts() # List[str]
|
| 122 |
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
| 123 |
+
clusters = self.__postprocess_clusters__(clusters,query)
|
| 124 |
return clusters, articles
|
| 125 |
|
| 126 |
if platforn_name == 'IEEE':
|