# import random
# import sqlite3
# import time
# from googleapiclient.discovery import build
# from google.oauth2 import service_account
# from googleapiclient.errors import HttpError
# import pandas as pd
# import requests
# from bs4 import BeautifulSoup
# import pickle
# import tldextract
import os
from dotenv import load_dotenv
# from langchain.schema import Document
# from langchain.vectorstores.utils import DistanceStrategy
# from torch import cuda, bfloat16
# import torch
# import transformers
# from transformers import AutoTokenizer
# from langchain.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import LlamaCpp
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA  # RetrievalQAWithSourcesChain
# from config import IFCN_LIST_URL

IFCN_FILENAME = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                             'ifcn_df.csv')

load_dotenv()
DB_PATH = os.getenv('DB_PATH')
FAISS_DB_PATH = os.getenv('FAISS_DB_PATH')
MODEL_PATH = os.getenv('MODEL_PATH')
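
# A quick fail-fast check (an illustrative addition, not part of the original
# module): get_rag_chain() below cannot run without these two paths, so raise
# a clear error here rather than a cryptic one deep inside langchain.
if FAISS_DB_PATH is None or MODEL_PATH is None:
    raise RuntimeError('FAISS_DB_PATH and MODEL_PATH must be set in the .env file')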

# def get_claims(claims_serv, query_str, lang_code):
#     """Queries the Google Fact Check API using the search string and returns the results
#
#     Args:
#         claims_serv (build().claims() object): build() creates a service object
#             for the factchecktools API; claims() creates a 'claims' object which
#             can be used to query with the search string
#         query_str (str): the query string
#         lang_code (str): BCP-47 language code, used to restrict search results by language
#
#     Returns:
#         list: the list of all search results returned by the API
#     """
#     claims = []
#     req = claims_serv.search(query=query_str, languageCode=lang_code)
#     try:
#         res = req.execute()
#         # 'claims' is absent from the response when there are no results, so
#         # use .get() (this caused the KeyError noted in an earlier FIXME)
#         claims = res.get('claims', [])
#     except HttpError as e:
#         print('Error response status code : {0}, reason : {1}'.format(e.status_code, e.error_details))
#         return claims  # res is undefined after an HttpError, so stop here
#     # Aggregate all the results pages into one object
#     while 'nextPageToken' in res.keys():
#         req = claims_serv.search_next(req, res)
#         res = req.execute()
#         claims.extend(res.get('claims', []))
#     # TODO: Also return any basic useful metrics based on the results
#     return claims

# def reformat_claims(claims):
#     """Reformats the list of nested claims / search results into a DataFrame
#
#     Args:
#         claims (list): list of nested claims / search results
#
#     Returns:
#         pd.DataFrame: DataFrame containing one search result per row
#     """
#     # Flatten the nested claimReview records into their own columns
#     df = pd.DataFrame(claims)
#     df = df.explode('claimReview').reset_index(drop=True)
#     claim_review_df = pd.json_normalize(df['claimReview'])
#     return pd.concat([df.drop('claimReview', axis=1), claim_review_df], axis=1)

# def certify_claims(claims_df):
#     """Certifies all the search results from the API against a list of verified IFCN signatories
#
#     Args:
#         claims_df (pd.DataFrame): DataFrame object containing all search results from the API
#
#     Returns:
#         pd.DataFrame: claims dataframe filtered to include only IFCN-certified claims
#     """
#     ifcn_to_use = get_ifcn_to_use()
#     claims_df['ifcn_check'] = claims_df['publisher.site'].apply(remove_subdomain).isin(ifcn_to_use)
#     return claims_df[claims_df['ifcn_check']].drop('ifcn_check', axis=1)

# def get_ifcn_data():
#     """Standalone function to update the IFCN signatories CSV file that is stored locally"""
#     r = requests.get(IFCN_LIST_URL)
#     soup = BeautifulSoup(r.content, 'html.parser')
#     cats_list = soup.find_all('div', class_='row mb-5')
#     active = cats_list[0].find_all('div', class_='media')
#     active = extract_ifcn_df(active, 'active')
#     under_review = cats_list[1].find_all('div', class_='media')
#     under_review = extract_ifcn_df(under_review, 'under_review')
#     expired = cats_list[2].find_all('div', class_='media')
#     expired = extract_ifcn_df(expired, 'expired')
#     ifcn_df = pd.concat([active, under_review, expired], axis=0, ignore_index=True)
#     # str.strip() removes a *set of characters*, not a prefix, which mangles
#     # values such as 'from Morocco'; removeprefix() drops the literal prefix only
#     ifcn_df['country'] = ifcn_df['country'].str.removeprefix('from ')
#     ifcn_df['verified_date'] = ifcn_df['verified_date'].str.removeprefix('Verified on ')
#     ifcn_df.to_csv(IFCN_FILENAME, index=False)

# def extract_ifcn_df(ifcn_list, status):
#     """Returns useful info from a list of IFCN signatories
#
#     Args:
#         ifcn_list (list): list of IFCN signatories
#         status (str): status code to be used for all signatories in this list
#
#     Returns:
#         pd.DataFrame: a dataframe of IFCN signatories' data
#     """
#     ifcn_data = [{
#         'url': x.a['href'],
#         'name': x.h5.text,
#         'country': x.h6.text,
#         'verified_date': x.find_all('span', class_='small')[1].text,
#         'ifcn_profile_url':
#             x.find('a', class_='btn btn-sm btn-outline btn-link mb-0')['href'],
#         'status': status
#     } for x in ifcn_list]
#     return pd.DataFrame(ifcn_data)

# def remove_subdomain(url):
#     """Removes the subdomain from a URL hostname - useful when comparing two URLs
#
#     Args:
#         url (str): URL hostname
#
#     Returns:
#         str: URL with subdomain removed
#     """
#     extract = tldextract.extract(url)
#     return extract.domain + '.' + extract.suffix
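# # Illustrative examples (hypothetical inputs; tldextract splits a hostname into
# # subdomain / registered domain / public suffix, so both normalise to one site):
# #     remove_subdomain('edition.cnn.com')            -> 'cnn.com'
# #     remove_subdomain('https://www.bbc.co.uk/news') -> 'bbc.co.uk'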

# def get_ifcn_to_use():
#     """Returns the IFCN data for non-expired signatories
#
#     Returns:
#         list: URLs of non-expired IFCN signatories, with subdomains removed
#     """
#     ifcn_df = pd.read_csv(IFCN_FILENAME)
#     ifcn_url = ifcn_df.loc[ifcn_df.status.isin(['active', 'under_review']), 'url']
#     return [remove_subdomain(x) for x in ifcn_url]

# def get_gapi_service():
#     """Returns a Google Fact-Check API-specific service object used to query the API
#
#     Returns:
#         googleapiclient.discovery.Resource: API-specific service object
#     """
#     load_dotenv()
#     environment = os.getenv('ENVIRONMENT')
#     if environment == 'DEVELOPMENT':
#         api_key = os.getenv('API_KEY')
#         service = build('factchecktools', 'v1alpha1', developerKey=api_key)
#     elif environment == 'PRODUCTION':
#         google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
#         # FIXME: The below credentials are not working; the HTTP request throws HTTPError 400
#         # credentials = service_account.Credentials.from_service_account_file(
#         #     GOOGLE_APPLICATION_CREDENTIALS)
#         credentials = service_account.Credentials.from_service_account_file(
#             google_application_credentials,
#             scopes=['https://www.googleapis.com/auth/userinfo.email',
#                     'https://www.googleapis.com/auth/cloud-platform'])
#         service = build('factchecktools', 'v1alpha1', credentials=credentials)
#     else:
#         # Fail explicitly instead of hitting an UnboundLocalError on 'service'
#         raise ValueError('ENVIRONMENT must be DEVELOPMENT or PRODUCTION')
#     return service

# # USED IN update_database.py ----
# def get_claims_by_site(claims_serv, publisher_site, lang_code):
#     """Returns all claims reviewed by a given publisher site, aggregating paginated results"""
#     # TODO: Any HTTP or other errors in this function need to be handled better
#     req = claims_serv.search(reviewPublisherSiteFilter=publisher_site,
#                              languageCode=lang_code)
#     # Retry the first request until it succeeds, pausing after each HTTP error
#     while True:
#         try:
#             res = req.execute()
#             break
#         except HttpError as e:
#             print('Error response status code : {0}, reason : {1}'.format(
#                 e.status_code, e.error_details))
#             time.sleep(random.randint(50, 60))
#     if 'claims' in res:
#         claims = res['claims']
#         print('first 10')
#         req_prev, req = req, None
#         res_prev, res = res, None
#     else:
#         print('No data')
#         return []
#     # Aggregate all the results pages into one object
#     while 'nextPageToken' in res_prev.keys():
#         req = claims_serv.search_next(req_prev, res_prev)
#         try:
#             res = req.execute()
#             claims.extend(res.get('claims', []))
#             req_prev, req = req, None
#             res_prev, res = res, None
#             print('another 10')
#         except HttpError as e:
#             print('Error in while loop : {0}, reason : {1}'.format(
#                 e.status_code, e.error_details))
#             time.sleep(random.randint(50, 60))
#     return claims

# def rename_claim_attrs(df):
#     """Renames the API's camelCase claim attributes to snake_case column names"""
#     return df.rename(
#         columns={'claimDate': 'claim_date',
#                  'reviewDate': 'review_date',
#                  'textualRating': 'textual_rating',
#                  'languageCode': 'language_code',
#                  'publisher.name': 'publisher_name',
#                  'publisher.site': 'publisher_site'}
#     )


# def clean_claims(df):
#     # TODO: Not yet implemented
#     pass

# def write_claims_to_db(df):
#     with sqlite3.connect(DB_PATH) as db_con:
#         df.to_sql('claims', db_con, if_exists='append', index=False)
#     # FIXME: The id variable is not getting auto-incremented
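
# # One possible fix for the auto-increment FIXME above (a sketch, not the
# # original author's code; the helper name and column list are guesses based
# # on rename_claim_attrs() above): create the table up front with an explicit
# # INTEGER PRIMARY KEY and append rows that carry no 'id' column, so SQLite
# # assigns ids itself.
# def create_claims_table():
#     with sqlite3.connect(DB_PATH) as db_con:
#         db_con.execute("""CREATE TABLE IF NOT EXISTS claims (
#                               id INTEGER PRIMARY KEY AUTOINCREMENT,
#                               text TEXT,
#                               claimant TEXT,
#                               claim_date TEXT,
#                               review_date TEXT,
#                               textual_rating TEXT,
#                               language_code TEXT,
#                               publisher_name TEXT,
#                               publisher_site TEXT,
#                               url TEXT)""")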

# def generate_and_store_embeddings(df, embed_model, overwrite):
#     # TODO: Combine "text" & "textual_rating" to generate useful statements
#     df['fact_check'] = 'The fact-check result for the claim "' + df['text'] \
#         + '" is "' + df['textual_rating'] + '"'
#     # TODO: Are ids required?
#     df.rename(columns={'text': 'claim'}, inplace=True)
#     docs = [Document(page_content=row['fact_check'],
#                      metadata=row.drop('fact_check').to_dict())
#             for idx, row in df.iterrows()]
#     if overwrite:
#         db = FAISS.from_documents(docs, embed_model,
#                                   distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
#         # FIXME: MAX_INNER_PRODUCT is not being used currently, only EUCLIDEAN_DISTANCE
#         db.save_local(FAISS_DB_PATH)
#     else:
#         db = FAISS.load_local(FAISS_DB_PATH, embed_model,
#                               distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
#         db.add_documents(docs)
#         db.save_local(FAISS_DB_PATH)
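# # Note on the FIXME above: all-mpnet-base-v2 embeddings come out unit-normalised
# # (the model's sentence-transformers pipeline ends in a Normalize step), and for
# # unit vectors Euclidean distance and inner product rank neighbours identically,
# # so the fallback to EUCLIDEAN_DISTANCE should not change retrieval results.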


def get_rag_chain():
    """Builds the retrieval-augmented QA chain over the local FAISS claims index

    Returns:
        RetrievalQA: chain that answers queries using retrieved fact-check documents
    """
    # The embedding model must match the one used when the FAISS index was built
    model_name = "sentence-transformers/all-mpnet-base-v2"
    model_kwargs = {"device": "cpu"}
    embed_model = HuggingFaceEmbeddings(model_name=model_name,
                                        model_kwargs=model_kwargs)
    llm = LlamaCpp(model_path=MODEL_PATH)
    # The pickled index is trusted local data, hence allow_dangerous_deserialization
    db_vector = FAISS.load_local(FAISS_DB_PATH, embed_model,
                                 allow_dangerous_deserialization=True)
    retriever = db_vector.as_retriever()
    # 'stuff' packs all retrieved documents into a single prompt for the LLM
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True
    )
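

# A minimal usage sketch (illustrative; the sample question is made up, and this
# assumes the index at FAISS_DB_PATH was built with the same embedding model):
if __name__ == '__main__':
    qa_chain = get_rag_chain()
    output = qa_chain({'query': 'Does drinking hot water cure COVID-19?'})
    print(output['result'])
    # Each source document carries the claim metadata stored with its embedding
    for doc in output['source_documents']:
        print(doc.page_content)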