Spaces:
Sleeping
Sleeping
Asaad Almutareb committed on
Commit ·
2e6490e
1
Parent(s): 5c0a79d
added sqlite schema and handling
Browse files
innovation_pathfinder_ai/database/db_handler.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlmodel import SQLModel, create_engine, Session, select
from innovation_pathfinder_ai.database.schema import Sources
from innovation_pathfinder_ai.utils.logger import get_console_logger

# SQLite database file lives in the current working directory.
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = create_engine(sqlite_url, echo=False)

logger = get_console_logger("db_handler")

# Create all tables declared on SQLModel metadata (currently just Sources)
# as an import-time side effect; a no-op if the tables already exist.
SQLModel.metadata.create_all(engine)
| 14 |
+
def read_one(hash_id: str):
    """Return the first Sources row whose hash_id matches, or None.

    Args:
        hash_id: unique content hash of the source (Sources.hash_id is a
            str in the schema; the original ``dict`` annotation was wrong).

    Returns:
        The matching Sources instance, or None when no row matches.
    """
    with Session(engine) as session:
        statement = select(Sources).where(Sources.hash_id == hash_id)
        sources = session.exec(statement).first()
        return sources
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def add_one(data: dict):
    """Insert a new Sources row built from ``data``.

    Duplicate detection is keyed on ``data["hash_id"]``: when a row with
    that hash already exists, nothing is inserted and None is returned.

    Returns:
        The freshly persisted Sources instance, or None on duplicate.
    """
    with Session(engine) as session:
        duplicate = session.exec(
            select(Sources).where(Sources.hash_id == data.get("hash_id"))
        ).first()
        if duplicate:
            logger.warning(f"Item with hash_id {data.get('hash_id')} already exists")
            return None  # or raise an exception, or handle as needed
        record = Sources(**data)
        session.add(record)
        session.commit()
        # Re-load server-generated fields (e.g. the autoincrement id).
        session.refresh(record)
        logger.info(f"Item with hash_id {data.get('hash_id')} added to the database")
        return record
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def update_one(hash_id: str, data: dict):
    """Update the Sources row identified by ``hash_id`` with ``data``.

    Args:
        hash_id: unique content hash of the row to update (str, not dict —
            the original annotation was wrong).
        data: attribute-name -> new-value mapping applied via setattr.

    Returns:
        The updated Sources instance, or None when no row matches.
    """
    with Session(engine) as session:
        # Check if the item with the given hash_id exists
        sources = session.exec(
            select(Sources).where(Sources.hash_id == hash_id)
        ).first()
        if not sources:
            logger.warning(f"No item with hash_id {hash_id} found for update")
            return None  # or raise an exception, or handle as needed
        for key, value in data.items():
            setattr(sources, key, value)
        session.commit()
        # Refresh so the returned object is not in the expired state
        # (attribute access after commit would otherwise re-query lazily).
        session.refresh(sources)
        logger.info(f"Item with hash_id {hash_id} updated in the database")
        return sources
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def delete_one(id: str):
    """Delete the Sources row whose hash_id equals ``id``.

    Args:
        id: the hash_id of the row to delete. Despite the parameter name,
            this is matched against Sources.hash_id, which is a str —
            the original ``int`` annotation was wrong.

    Returns:
        The deleted Sources instance on success, or None when no row
        matched. (The original returned None on every path, which made
        delete_many's success/failure check meaningless.)
    """
    with Session(engine) as session:
        # Check if the item with the given hash_id exists
        sources = session.exec(
            select(Sources).where(Sources.hash_id == id)
        ).first()
        if not sources:
            logger.warning(f"No item with hash_id {id} found for deletion")
            return None  # or raise an exception, or handle as needed
        session.delete(sources)
        session.commit()
        logger.info(f"Item with hash_id {id} deleted from the database")
        return sources
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def add_many(data: list):
    """Insert multiple source records, skipping duplicates.

    Each item is delegated to add_one(), which opens its own session,
    performs the duplicate check, commits, and logs success itself.
    The original version wrapped this loop in an outer Session that was
    never used — its trailing commit() committed nothing — so the outer
    session and the duplicate success logging are removed.

    Args:
        data: list of dicts, each a valid kwargs mapping for Sources.
    """
    for info in data:
        # Reuse add_one function for each item
        result = add_one(info)
        if result is None:
            logger.warning(
                f"Item with hash_id {info.get('hash_id')} could not be added"
            )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def delete_many(ids: list):
    """Delete multiple source records by hash_id.

    Each id is delegated to delete_one(), which opens its own session,
    commits, and logs the deletion itself. The original version wrapped
    this loop in an outer Session that was never used — its trailing
    commit() committed nothing — so the outer session is removed.

    Args:
        ids: list of hash_id values to delete.
    """
    for hash_id in ids:
        # Reuse delete_one function for each item
        result = delete_one(hash_id)
        if result is None:
            logger.warning(f"No item with hash_id {hash_id} found for deletion")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def read_all(query: dict = None):
    """Return all Sources rows, optionally filtered by equality conditions.

    Args:
        query: optional mapping of Sources attribute name -> required value;
            all conditions are ANDed. Keys must be existing Sources
            attributes or getattr raises AttributeError.

    Returns:
        List of matching Sources instances (possibly empty).
    """
    with Session(engine) as session:
        statement = select(Sources)
        if query:
            statement = statement.where(
                *[getattr(Sources, key) == value for key, value in query.items()]
            )
        sources = session.exec(statement).all()
        return sources
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def delete_all():
    """Delete every row from the Sources table.

    Bug fix: the original called ``session.exec(Sources).delete()``, but
    Session.exec expects an executable statement, not a model class, so it
    raised at runtime. We select all rows and delete them individually,
    which also keeps ORM-level delete behavior consistent with delete_one.
    """
    with Session(engine) as session:
        for row in session.exec(select(Sources)).all():
            session.delete(row)
        session.commit()
        logger.info("All items deleted from the database")
|
innovation_pathfinder_ai/database/schema.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlmodel import SQLModel, Field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import datetime
|
| 5 |
+
|
| 6 |
+
class Sources(SQLModel, table=True):
    """A tracked information source (arxiv paper, web page, wiki article)."""

    # Autoincrement surrogate primary key.
    id: Optional[int] = Field(default=None, primary_key=True)
    url: str = Field()
    title: Optional[str] = Field(default="NA", unique=False)
    # Unique content hash used as the lookup key by db_handler.
    hash_id: str = Field(unique=True)
    # Bug fix: the original used default=datetime.datetime.now().timestamp(),
    # which is evaluated ONCE at import time, stamping every row with the
    # process start time. default_factory defers evaluation to row creation.
    created_at: float = Field(
        default_factory=lambda: datetime.datetime.now().timestamp()
    )
    summary: str = Field(default="")
    embedded: bool = Field(default=False)

    # Allow re-definition when the module is imported twice (e.g. reloads).
    __table_args__ = {"extend_existing": True}
|
innovation_pathfinder_ai/structured_tools/structured_tools.py
CHANGED
|
@@ -6,31 +6,32 @@ from langchain_community.utilities import WikipediaAPIWrapper
|
|
| 6 |
#from langchain.tools import Tool
|
| 7 |
from langchain_community.utilities import GoogleSearchAPIWrapper
|
| 8 |
import arxiv
|
| 9 |
-
|
| 10 |
# hacky and should be replaced with a database
|
| 11 |
from innovation_pathfinder_ai.source_container.container import (
|
| 12 |
all_sources
|
| 13 |
)
|
| 14 |
-
from innovation_pathfinder_ai.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
@tool
|
| 17 |
def arxiv_search(query: str) -> str:
|
| 18 |
"""Search arxiv database for scientific research papers and studies. This is your primary information source.
|
| 19 |
always check it first when you search for information, before using any other tool."""
|
| 20 |
-
# return "LangChain"
|
| 21 |
global all_sources
|
| 22 |
-
arxiv_retriever = ArxivRetriever(load_max_docs=
|
| 23 |
data = arxiv_retriever.invoke(query)
|
| 24 |
meta_data = [i.metadata for i in data]
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# formatted_info = format_info_list(all_sources)
|
| 32 |
-
|
| 33 |
-
return meta_data.__str__()
|
| 34 |
|
| 35 |
@tool
|
| 36 |
def get_arxiv_paper(paper_id:str) -> None:
|
|
@@ -52,17 +53,13 @@ def get_arxiv_paper(paper_id:str) -> None:
|
|
| 52 |
@tool
|
| 53 |
def google_search(query: str) -> str:
|
| 54 |
"""Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
|
| 55 |
-
# return "LangChain"
|
| 56 |
global all_sources
|
| 57 |
|
| 58 |
websearch = GoogleSearchAPIWrapper()
|
| 59 |
-
search_results:dict = websearch.results(query,
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
# formatted_string = "Title: {title}, link: {link}, snippet: {snippet}".format(**organic_source)
|
| 64 |
-
cleaner_sources = ["Title: {title}, link: {link}, snippet: {snippet}".format(**i) for i in search_results]
|
| 65 |
-
|
| 66 |
all_sources += cleaner_sources
|
| 67 |
|
| 68 |
return cleaner_sources.__str__()
|
|
@@ -75,5 +72,9 @@ def wikipedia_search(query: str) -> str:
|
|
| 75 |
api_wrapper = WikipediaAPIWrapper()
|
| 76 |
wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
|
| 77 |
wikipedia_results = wikipedia_search.run(query)
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
#from langchain.tools import Tool
|
| 7 |
from langchain_community.utilities import GoogleSearchAPIWrapper
|
| 8 |
import arxiv
|
| 9 |
+
import ast
|
| 10 |
# hacky and should be replaced with a database
|
| 11 |
from innovation_pathfinder_ai.source_container.container import (
|
| 12 |
all_sources
|
| 13 |
)
|
| 14 |
+
from innovation_pathfinder_ai.utils.utils import (
|
| 15 |
+
parse_list_to_dicts, format_wiki_summaries, format_arxiv_documents, format_search_results
|
| 16 |
+
)
|
| 17 |
+
from innovation_pathfinder_ai.database.db_handler import (
|
| 18 |
+
add_many
|
| 19 |
+
)
|
| 20 |
|
| 21 |
@tool
def arxiv_search(query: str) -> str:
    """Search arxiv database for scientific research papers and studies. This is your primary information source.
    always check it first when you search for information, before using any other tool."""
    global all_sources
    arxiv_retriever = ArxivRetriever(load_max_docs=3)
    data = arxiv_retriever.invoke(query)
    # Removed dead code: the original also built a `meta_data` list of
    # document metadata that was never used after the refactor to
    # format_arxiv_documents().
    formatted_sources = format_arxiv_documents(data)
    all_sources += formatted_sources
    # Persist the formatted sources to the sqlite sources table.
    parsed_sources = parse_list_to_dicts(formatted_sources)
    add_many(parsed_sources)

    return str(data)
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
@tool
|
| 37 |
def get_arxiv_paper(paper_id:str) -> None:
|
|
|
|
| 53 |
@tool
def google_search(query: str) -> str:
    """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
    global all_sources

    websearch = GoogleSearchAPIWrapper()
    # GoogleSearchAPIWrapper.results returns a list of result dicts, not a
    # dict — the original `search_results: dict` annotation was wrong.
    search_results: list = websearch.results(query, 3)
    cleaner_sources = format_search_results(search_results)
    # Persist the formatted sources to the sqlite sources table.
    parsed_csources = parse_list_to_dicts(cleaner_sources)
    add_many(parsed_csources)
    all_sources += cleaner_sources

    return str(cleaner_sources)
|
|
|
|
| 72 |
api_wrapper = WikipediaAPIWrapper()
|
| 73 |
wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
|
| 74 |
wikipedia_results = wikipedia_search.run(query)
|
| 75 |
+
formatted_summaries = format_wiki_summaries(wikipedia_results)
|
| 76 |
+
all_sources += formatted_summaries
|
| 77 |
+
parsed_summaries = parse_list_to_dicts(formatted_summaries)
|
| 78 |
+
add_many(parsed_summaries)
|
| 79 |
+
|
| 80 |
+
return wikipedia_results.__str__()
|