import json
import sqlite3
import urllib.request

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from numpy.linalg import norm
from transformers import AutoModel

class JinaAIEmbeddingFunction(EmbeddingFunction):
    """Wrap a transformers embedding model in Chroma's EmbeddingFunction interface."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        # Chroma passes a list of documents; encode them and return plain lists.
        embeddings = self.model.encode(input)
        return embeddings.tolist()
|
|
# Load the embedding model once and reuse it everywhere below.
embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                            trust_remote_code=True,
                                            cache_dir='models')
ef = JinaAIEmbeddingFunction(embedding_model)
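
# Sketch: ef can be handed to any Chroma collection, e.g.
#   client.get_or_create_collection("papers", embedding_function=ef)
# (ArxivChroma below builds its own instance the same way.)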
|
|
# Load topic descriptions (a JSON mapping of topic name -> description)
# and pre-compute one embedding per topic for similarity matching.
with open("topic_descriptions.txt") as f:
    topic_descriptions = json.load(f)
topics = list(topic_descriptions.keys())
embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topics]
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
|
|
def choose_topic(summary):
    """Return the topic whose description embedding is most similar to the summary."""
    embed = embedding_model.encode(summary)
    topic = ""
    max_sim = 0.
    for i, key in enumerate(topics):
        sim = cos_sim(embed, embeddings[i])
        if sim > max_sim:
            topic = key
            max_sim = sim
    return topic
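
# Example (sketch, not called anywhere): choose_topic maps free text to the
# closest topic key. The abstract below is made up for illustration; the
# return value is one of the keys of topic_descriptions.
def _demo_choose_topic():
    return choose_topic("We propose a transformer model for time-series forecasting.")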
|
|
def authors_list_to_str(authors):
    """Input a list of authors, return a comma-separated author string."""
    return ", ".join(authors)
|
|
def authors_str_to_list(string):
    """Input a string of authors joined by "and", return a list of authors."""
    authors = []
    # Split on " and " (with spaces) so names containing "and" stay intact.
    for author in string.split(" and "):
        if author.strip() != "et al.":
            authors.append(author.strip())
    return authors
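
# Sketch: the two helpers are rough inverses of each other, e.g.
#   authors_list_to_str(["Jane Doe", "John Roe"])  -> "Jane Doe, John Roe"
#   authors_str_to_list("Jane Doe and John Roe")   -> ["Jane Doe", "John Roe"]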
|
|
def chunk_texts(text, max_char=400):
    """
    Chunk a long text into several chunks of at most max_char characters,
    making sure no word is cut in half.
    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
    Returns:
        A list of chunks.
    """
    chunks = []
    current_chunk = ""
    for word in text.split():
        # When adding this word would exceed the limit, close the current
        # chunk and start the next one with the word instead of dropping it.
        if len(current_chunk) + len(word) + 1 > max_char:
            chunks.append(current_chunk.strip())
            current_chunk = word
        else:
            current_chunk += " " + word
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
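
# Sketch: chunking respects word boundaries, e.g.
#   chunk_texts("one two three", max_char=8)  -> ["one two", "three"]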
|
|
def trimming(txt):
    """Trim txt down to its outermost {...} block and flatten newlines."""
    start = txt.find("{")
    end = txt.rfind("}")
    return txt[start:end+1].replace("\n", " ")
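
# Sketch: trimming isolates a JSON object from surrounding chatter, e.g.
#   trimming('reply: {"a":\n1} end')  -> '{"a": 1}'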
|
|
def extract_tag(txt, tagname):
    """Return the text between <tagname> and </tagname> (first occurrence)."""
    return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]
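
# Sketch: extract_tag is a minimal substring-based tag reader, e.g.
#   extract_tag("<title>Attention Is All You Need</title>", "title")
#   -> "Attention Is All You Need"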
|
|
def get_record(extract):
    """Parse one Atom <entry> body into [id, updated, published, title, authors, link, summary]."""
    id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    title = extract_tag(extract, "title").replace("\n ", "").strip()
    summary = extract_tag(extract, "summary").replace("\n", "").strip()
    authors = []
    # Consume one <author> block per iteration, collecting each <name>.
    while extract.find("<author>") != -1:
        author = extract_tag(extract, "name")
        extract = extract[extract.find("</author>")+9:]
        authors.append(author)
    pattern = '<link title="pdf" href="'
    link_start = extract.find(pattern)
    link = extract[link_start+len(pattern):extract.find("rel=", link_start)-2]
    return [id, updated, published, title, authors, link, summary]
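
# Sketch: get_record on a minimal, hand-made Atom entry body (fields abbreviated).
def _demo_get_record():
    entry = (
        "<id>http://arxiv.org/abs/0000.00000v1</id>"
        "<updated>2024-01-01T00:00:00Z</updated>"
        "<published>2024-01-01T00:00:00Z</published>"
        "<title>A Made-Up Title</title>"
        "<summary>A made-up summary.</summary>"
        "<author><name>Jane Doe</name></author>"
        '<link title="pdf" href="http://arxiv.org/pdf/0000.00000v1" rel="related"/>'
    )
    return get_record(entry)  # -> [id, updated, published, title, [authors], link, summary]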
|
|
def crawl_exact_paper(title, author, max_results=3):
    """Query the arXiv API for a specific paper by title and author list."""
    authors = authors_list_to_str(author)
    records = []
    url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(
        title=title, author=authors, max_results=max_results)
    url = url.replace(" ", "%20")
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        # Walk the feed one <entry> at a time.
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)
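
# Example (sketch): exact-match crawl; the title and author here are made up.
#   crawl_exact_paper("A Made-Up Title", ["Jane Doe"])
# returns a list of [topic, id, updated, published, title, authors, link, summary]
# rows, or an "Error: ..." string on failure.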
|
|
def crawl_arxiv(keyword_list, max_results=100):
    """Query the arXiv API for papers matching any keyword in keyword_list."""
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    records = []
    # Build an OR-query over all keywords: all:kw1+OR+all:kw2+...
    for i, keyword in enumerate(keyword_list):
        if i == 0:
            url = baseurl + 'all:' + keyword
        else:
            url = url + '+OR+' + 'all:' + keyword
    url = url + '&max_results=' + str(max_results)
    url = url.replace(' ', '%20')
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        # Walk the feed one <entry> at a time.
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>")+8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)
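
# Example (sketch): keyword crawl over made-up keywords.
#   crawl_arxiv(["retrieval augmented generation", "vector database"], max_results=10)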
|
|
class ArxivSQL:
    def __init__(self, table="arxivsql", name="arxiv_records_sql"):
        self.con = sqlite3.connect(name)
        self.cur = self.con.cursor()
        self.table = table

    def query(self, title="", author=[]):
        # Build a LIKE-based filter; "1=1" keeps the WHERE clause valid
        # when a field is left unspecified.
        if len(title) > 0:
            query_title = 'title like "%{}%"'.format(title)
        else:
            query_title = "1=1"
        if len(author) > 0:
            # OR together one LIKE clause per author, parenthesized so the
            # OR group combines correctly with the title condition.
            query_author = ""
            for auth in author:
                query_author += "authors like '%{}%' or ".format(auth)
            query_author = "(" + query_author[:-4] + ")"
        else:
            query_author = "1=1"
        query = "select * from {} where {} and {}".format(self.table, query_title, query_author)
        result = self.cur.execute(query)
        return result.fetchall()
|
|
    def query_id(self, ids=[]):
        if len(ids) == 0:
            return None
        query = "select * from {} where id in (".format(self.table)
        for id in ids:
            query += "'" + id + "',"
        query = query[:-1] + ")"
        try:
            result = self.cur.execute(query)
            return result.fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", query)

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return a string of any errors encountered (empty on success).
        """
        results = ""
        for record in crawl_records:
            try:
                query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    record[1][21:],                # arXiv id, "http://arxiv.org/abs/" prefix stripped
                    record[0],                     # topic
                    record[4].replace('"', "'"),   # title, double quotes escaped
                    authors_list_to_str(record[5]),
                    record[2][:10],                # updated date (YYYY-MM-DD)
                    record[3][:10],                # published date (YYYY-MM-DD)
                    record[6]                      # pdf link
                )
                self.cur.execute(query)
                self.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        return results


sqldb = ArxivSQL()
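
# Example (sketch): typical round trip through the SQL index. The keyword is
# made up; add() expects rows produced by crawl_arxiv/crawl_exact_paper.
def _demo_sql():
    records = crawl_arxiv(["retrieval augmented generation"], max_results=5)
    if not isinstance(records, str):  # the crawlers return an error string on failure
        sqldb.add(records)
    return sqldb.query(title="retrieval")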
|
|
class ArxivChroma:
    """
    Interface to arxivdb, which only supports querying and addition.
    This interface does not support editing or deletion.
    """
    def __init__(self, table="arxiv_records", name="arxivdb/"):
        self.client = chromadb.PersistentClient(name)
        # Reuse the same Jina embedding model as the module-level one.
        self.model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                               trust_remote_code=True,
                                               cache_dir='models')
        self.collection = self.client.get_or_create_collection(table,
                                                               embedding_function=JinaAIEmbeddingFunction(
                                                                   model=self.model
                                                               ))

    def query_relevant(self, keywords, query_texts, n_results=3):
        """
        Perform a semantic query with query_texts, restricted to documents
        that contain at least one of the given keywords (list of str).
        """
        contains = [{"$contains": keyword.lower()} for keyword in keywords]
        return self.collection.query(
            query_texts=query_texts,
            where_document={
                "$or": contains
            },
            n_results=n_results,
        )
| |
    def query_exact(self, id):
        # Chunk ids follow the "<paper_id>_<chunk_index>" convention used in add().
        ids = ["{}_{}".format(id, j) for j in range(0, 10)]
        return self.collection.get(ids=ids)
|
|
    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from the arxiv crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, authors, link, summary]
        Return the final length of the collection.
        """
        for record in crawl_records:
            embed_text = """
            Topic: {},
            Title: {},
            Summary: {}
            """.format(record[0], record[4], record[7])
            chunks = chunk_texts(embed_text)
            # One document per chunk, keyed "<paper_id>_<chunk_index>".
            ids = [record[1][21:] + "_" + str(j) for j in range(len(chunks))]
            paper_ids = [{"paper_id": record[1][21:]} for _ in range(len(chunks))]
            self.collection.add(
                documents=chunks,
                metadatas=paper_ids,
                ids=ids
            )
        return self.collection.count()
|
|
db = ArxivChroma()
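
# Example (sketch): vector-side round trip, mirroring _demo_sql above. Chroma's
# $or filter needs at least two clauses, hence two made-up keywords.
def _demo_chroma():
    records = crawl_arxiv(["retrieval augmented generation"], max_results=5)
    if not isinstance(records, str):
        db.add(records)
    return db.query_relevant(keywords=["retrieval", "augmentation"],
                             query_texts="How does retrieval augmentation help LLMs?")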
|
|
|
|