# ai_workflows/app/workflows/research_article_suggester.py
# theRealNG — "restructure the codebase" (commit 1f39bb8)
from crewai import Agent, Task, Crew
from langchain_openai import ChatOpenAI
from tavily import TavilyClient
from semanticscholar import SemanticScholar
import arxiv
import os
import json
from pydantic import BaseModel, Field
from crewai.tasks.task_output import TaskOutput
from datetime import datetime, timedelta
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.output_parsers import JsonOutputParser
from workflows.tools.scrape_website import scrape_tool, CustomScrapeWebsiteTool
MAX_RESULTS = 2  # max papers fetched per source (Tavily / arXiv / Semantic Scholar)
AGE_OF_RESEARCH_PAPER = 60  # papers published more than this many days ago are rejected
class RecentArticleSuggester:
    """
    Suggests recent research papers based on a given topic.

    Candidates are aggregated from Tavily, arXiv and Semantic Scholar, then
    filtered by recency and turned into reading pitches via an LLM.
    """

    def __init__(self):
        # Tavily search client; requires TAVILY_API_KEY in the environment.
        self.tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))

    def kickoff(self, inputs=None):
        """Entry point: return pitched paper suggestions for inputs["topic"].

        Raises KeyError when "topic" is absent (same contract as before).
        """
        # Fixed mutable default argument (was `inputs={}`), which is shared
        # between calls in Python; behavior for callers is unchanged.
        inputs = {} if inputs is None else inputs
        self.topic = inputs["topic"]
        suggested_research_papers = self._suggest_research_papers()
        return suggested_research_papers
def _suggest_research_papers(self):
query = f"research papers on {self.topic} published in the last week"
results = []
print("\nSearching for papers on Tavily...")
results = self.tavily_client.search(
query, max_results=MAX_RESULTS)['results']
print("\nSearching for papers on Arxiv...")
arxiv_results = arxiv.Search(
query=self.topic,
max_results=MAX_RESULTS,
sort_by=arxiv.SortCriterion.SubmittedDate
)
for result in arxiv_results.results():
paper = {
"title": result.title,
"authors": ", ".join(str(author) for author in result.authors),
"content": result.summary,
# "published_on": result.submitted.date(),
"url": result.entry_id,
"pdf_url": result.pdf_url
}
results.append(paper)
print("\nSearching for papers on Semanticscholar...")
sch = SemanticScholar()
semantic_results = sch.search_paper(
self.topic, sort='publicationDate:desc', bulk=True,
fields=['title', 'url', 'authors', 'publicationDate', 'abstract'])
for result in semantic_results[:MAX_RESULTS]:
paper = {
"title": result.title,
"authors": ", ".join(str(author.name) for author in result.authors),
"content": result.abstract,
"published_on": result.publicationDate,
"url": result.url,
}
results.append(paper)
# pitch_crew = self._create_pitch_crew()
research_paper_suggestions = []
for result in results:
try:
info = self._article_pitch(result)
# info = pitch_crew.kickoff(inputs={
# "title": result["title"],
# "url": result["url"],
# "content": result["content"]
# })
if info is not None:
research_paper_suggestions = research_paper_suggestions + \
[info]
except BaseException as e:
print(
f"Error processing article '{result['title']}': {e}\n\n {e.__traceback__}")
return research_paper_suggestions
def _gather_information(self, article):
print(f"\nScraping website: {article['url']}")
article_content = CustomScrapeWebsiteTool(article["url"])
print(f"\nGathering information from website: {article['url']}")
parser = JsonOutputParser(pydantic_object=ResearchPaper)
prompt_template = ChatPromptTemplate.from_messages([
SystemMessage(
"You are Research Paper Information Retriever. You are an expert in gathering required details about the given research paper."
"Your personal goal is: Retrieve the author information and date the research paper was published in the format of dd/mm/yyyy."
f"Formatting Instructions: {parser.get_format_instructions()}"
),
HumanMessage(
f"Here is the information about the research paper:\n {article}\n\n"
f"Research Paper content:\n{article_content}"
)
])
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2)
information_scrapper_chain = prompt_template | llm | parser
article_info = information_scrapper_chain.invoke({})
print("\nGathered Article Info: ", article_info)
article_info['article_content'] = article_content
return article_info
def _article_pitch(self, article):
article_info = self._gather_information(article)
try:
date_obj = datetime.strptime(
article_info['published_on'], "%d/%m/%Y")
start_date = datetime.now() - timedelta(days=AGE_OF_RESEARCH_PAPER)
# Compare if the input date is older
if date_obj < start_date:
print(
f"\nRejecting research paper {article['title']} because it was published on {date_obj},"
f" which is before the expected timeframe {start_date} & {datetime.now()}")
return None
except ValueError:
print("Invalid date format. Please use dd/mm/yyyy.")
return None
print(f"\nCreating pitch for the research paper: {article['title']}")
pitch_parser = JsonOutputParser(pydantic_object=ResearchPaperWithPitch)
pitch_template = ChatPromptTemplate.from_messages([
SystemMessage(
"You are Curiosity Catalyst. As a Curiosity Catalyst, you know exactly how to pique the user's curiosity to read the research paper."
"Your personal goal is: To pique the user's curiosity to read the research paper."
"Read the Research Paper Content to create a pitch."
f"Formatting Instructions: {pitch_parser.get_format_instructions()}"
),
HumanMessage(
f"Here is the information about the research paper:\n {article_info}\n\n"
f"Research Paper content:\n{article_info['article_content']}"
)
])
pitch_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2)
pitcher_chain = pitch_template | pitch_llm | pitch_parser
article_pitch = pitcher_chain.invoke({})
print("\nResearch Paper with the pitch: ", article_pitch)
return article_pitch
# Deprecated
def _create_pitch_crew(self):
information_gatherer = Agent(
role="Research Paper Information Retriever",
goal="Gather required information for the given research papers.",
verbose=True,
backstory=(
"You are an expert in gathering required details "
"about the given research paper."
),
llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2),
tools=[scrape_tool],
)
def evaluator(output: TaskOutput):
article_info = json.loads(output.exported_output)
try:
date_obj = datetime.strptime(
article_info['published_on'], "%d/%m/%Y")
start_date = datetime.now() - timedelta(days=AGE_OF_RESEARCH_PAPER)
# Compare if the input date is older
if date_obj < start_date:
raise BaseException(
f"{date_obj} Older than given timeframe {start_date}")
except ValueError:
print("Invalid date format. Please use dd/mm/yyyy.")
return False
information_gathering_task = Task(
description=(
"Here is the information of a research paper: title {title}, "
"url: {url} and content: {content}.\n"
"Gather following information about the research paper: "
"1. When was the research paper published and present it in dd/mm/yyyy format. "
"2. Who is the author of the research paper. "
),
expected_output=(
"Following details of the research paper: title, url, "
"content/summary, date it was published and author."
),
agent=information_gatherer,
async_exection=False,
output_json=ResearchPaper,
callback=evaluator,
)
pitcher = Agent(
role="Curiosity Catalyst",
goal="To pique the user's curiosity to read the research paper.",
verbose=True,
backstory=(
"As a Curiosity Catalyst, you know exactly how to pique the user's curiosity "
"to read the research paper."
),
llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2),
tools=[scrape_tool],
)
create_pitch = Task(
description=(
"Craft the pitch so to that it teases the research paper's most intriguing aspects, "
"by posing questions that the research paper might answer or "
"highlighting surprising facts to pique the user's curiosity "
" to read the research paper so that he is up-to-date with latest research."
),
expected_output=(
"All the details of the research paper along with the pitch."
),
tools=[scrape_tool],
agent=pitcher,
context=[information_gathering_task],
output_json=ResearchPaperWithPitch,
)
crew = Crew(
agents=[information_gatherer, pitcher],
tasks=[information_gathering_task, create_pitch],
verbose=True,
max_rpm=4,
)
return crew
class ResearchPaper(BaseModel):
    """Structured output schema for the information-gathering step.

    The field descriptions are embedded in the LLM's format instructions,
    so the typos ("publised"/"foramt") were fixed — they were sent verbatim
    to the model.
    """
    title: str
    url: str
    summary: str
    author: str = Field(description="author of the article")
    published_on: str = Field(
        description="Date the article was published on in format dd/mm/yyyy")
class ResearchPaperWithPitch(BaseModel):
    """Structured output schema for the pitching step: ResearchPaper fields
    plus the generated pitch.

    The field descriptions are embedded in the LLM's format instructions,
    so the typos ("publised"/"foramt") were fixed — they were sent verbatim
    to the model.
    """
    title: str
    url: str
    summary: str
    author: str = Field(description="author of the article")
    published_on: str = Field(
        description="Date the article was published on in format dd/mm/yyyy")
    pitch: str