# Lead crawling / extraction / classification pipeline entry point.
import argparse
import itertools
import json
import math
import multiprocessing as mp
import os
import sys
import time
from pathlib import Path
from typing import List

import jinja2
import joblib
import pandas as pd
import requests
import tiktoken
import yaml
from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI
from serpapi import GoogleSearch
from tqdm import tqdm

from batch import postprocess_result
from data import get_leads, format_search_results
from model import llm
from pipeline import (get_serp, get_condensed_result, get_organic_result, get_googlemap_results,
                      crawl_results, crawl_results_mp,
                      compose_extraction, extract_results, extract_results_mp,
                      compose_classification, classify_results, classify_results_mp,
                      compose_regularization, regularize_results, regularize_results_mp,
                      compose_filter, filter_results, filter_results_mp)
from utils import (parse_json_garbage, split_dataframe, merge_results,
                   combine_results, split_dict, format_df,
                   clean_quotes, compose_query, reverse_category2supercategory)
# Load environment variables from a .env file before reading API credentials.
load_dotenv()
# OpenAI organization id used when constructing the OpenAI client in __main__.
ORGANIZATION_ID = os.getenv('OPENAI_ORGANIZATION_ID')
# SerpAPI key (note the env var name is SERP_APIKEY, not SERP_API_KEY).
SERP_API_KEY = os.getenv('SERP_APIKEY')
# Serper.dev API key (alternative SERP provider).
SERPER_API_KEY = os.getenv('SERPER_API_KEY')
def continue_missing(args):
    """Re-run the pipeline for leads that never produced formatted results.

    Compares the original lead list against the formatted-results CSV in
    ``args.output_dir``, writes the rows whose positions are absent to
    ``args.missing_data_path``, then re-invokes :func:`main` with the data
    path and output directory redirected to the "missing" locations.

    Argument
        args: argparse.Namespace with data_path, output_dir,
              formatted_results_path, missing_data_path,
              output_missing_dir and n_processes attributes (mutated here).
    """
    data = get_leads(args.data_path)
    n_data = data.shape[0]
    formatted_results = pd.read_csv(
        os.path.join(args.output_dir, args.formatted_results_path))
    # Build the membership set once: O(1) lookups instead of re-scanning
    # the unique() array for every row position (was O(n^2)).
    processed = set(formatted_results['index'].unique())
    missing_indices = [i for i in range(n_data) if i not in processed]
    for i in missing_indices:
        logger.debug(f"{i} is not found")
    if not missing_indices:
        logger.debug("No missing data")
        return
    # .iloc: missing_indices are row positions, not index labels, so .loc
    # would select wrong rows whenever data carries a non-default index.
    missing_data = data.iloc[missing_indices]
    os.makedirs(args.output_missing_dir, exist_ok=True)
    # Also make sure the missing-data CSV's own directory exists; it is not
    # necessarily the same as output_missing_dir.
    Path(args.missing_data_path).parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): header=False assumes get_leads() reads CSVs without a
    # header row — confirm, since the original data file may include one.
    missing_data.to_csv(args.missing_data_path, index=False, header=False)
    args.data_path = args.missing_data_path
    args.output_dir = args.output_missing_dir
    if missing_data.shape[0] < args.n_processes:
        args.n_processes = 1
    main(args)
def main(args):
    """Run the crawl -> extract -> regularize -> postprocess pipeline.

    Each stage writes an intermediate file under ``args.output_dir``; every
    stage after the first loads its input from the previous stage's file, so
    the pipeline can be resumed from any single stage via ``args.steps``.
    When ``args.steps`` names a single stage, only that stage runs.

    Argument
        args: argparse.Namespace — see the argument parser in __main__.

    Note
        Rough timings for 200 records:
        crawl: 585.33s, extract: 2791.63s (delay=10),
        classify: 2374.49s (delay=10)
    """
    steps = args.steps

    def _joined(name):
        # Resolve an intermediate-file name against the output directory;
        # None passes through so optional stage files can be omitted.
        return os.path.join(args.output_dir, name) if name is not None else None

    crawled_file_path = _joined(args.crawled_file_path)
    extracted_file_path = _joined(args.extracted_file_path)
    postprocessed_file_path = _joined(args.postprocessed_file_path)
    filtered_file_path = _joined(args.filtered_file_path)
    regularized_file_path = _joined(args.regularized_file_path)

    ## Load the lead list ##
    data = get_leads(args.data_path)

    ## Crawl ##
    if steps in ('all', 'crawl'):
        Path(crawled_file_path).parent.mkdir(parents=True, exist_ok=True)
        crawled_results = crawl_results_mp(
            data,
            crawled_file_path,
            serp_provider=args.serp_provider,
            n_processes=args.n_processes
        )
        if steps == 'crawl':
            return

    ## Extract key information and classify ##
    # BUG FIX: the original exited with sys.exit(0) whenever steps was not
    # 'all'/'crawl', so --steps extract/regularize/postprocess never ran.
    if steps in ('all', 'extract'):
        assert os.path.exists(crawled_file_path), f"# CRAWLED file not found: {crawled_file_path}"
        # joblib.load accepts a path directly; the original passed an open()
        # handle that was never closed.
        crawled_results = joblib.load(crawled_file_path)
        extracted_results = extract_results_mp(
            crawled_results=crawled_results['crawled_results'],
            extracted_file_path=extracted_file_path,
            classes=args.classes,
            provider=args.extraction_provider,
            model=args.extraction_model,
            n_processes=args.n_processes
        )
        if steps == 'extract':
            return

    ## Regularize classification results ##
    if steps in ('all', 'regularize'):
        # BUG FIX: the assert previously checked args.extracted_file_path
        # (the bare file name) instead of the path joined with output_dir.
        assert os.path.exists(extracted_file_path), f"# extracted result file not found: {extracted_file_path}"
        extracted_results = joblib.load(extracted_file_path)
        # Renamed from regularize_results to avoid shadowing the imported
        # pipeline.regularize_results function.
        regularized_results = regularize_results_mp(
            extracted_results['extracted_results'],
            regularized_file_path,
            provider=args.regularization_provider,
            model=args.regularization_model
        )
        if steps == 'regularize':
            return

    ## Postprocess results ##
    if steps in ('all', 'postprocess'):
        # BUG FIX: assert now checks the joined path, and the message refers
        # to the regularized file (it previously said "extracted").
        assert os.path.exists(regularized_file_path), f"# regularized result file not found: {regularized_file_path}"
        regularized_results = joblib.load(regularized_file_path)
        # category2supercategory is defined at module level in __main__ from
        # the YAML config.
        postprossed_results = postprocess_result(
            regularized_results['regularized_results'],
            postprocessed_file_path,
            category2supercategory
        )
if __name__ == '__main__':
    # Module-level defaults. Several of these (base, engine, google_domain,
    # gl, lr, n_processes, client) are not read directly below; they are kept
    # for backward compatibility — imported helpers may reference them.
    base = "https://serpapi.com/search.json"
    engine = 'google'
    google_domain = 'google.com.tw'
    gl = 'tw'
    lr = 'lang_zh-TW'
    n_processes = 4
    client = OpenAI(organization=ORGANIZATION_ID)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='config/config.yml', help="Path to the configuration file")
    parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
    parser.add_argument("--missing_data_path", type=str, default="data/missing/missing.csv")
    parser.add_argument("--task", type=str, default="new", choices=["new", "continue"], help="new or continue")
    parser.add_argument("--steps", type=str, default="all", choices=["all", "crawl", "extract", "regularize", "postprocess"], help="new or continue")
    parser.add_argument("--output_dir", type=str, help='output directory')
    parser.add_argument("--output_missing_dir", type=str, help='output missing directory')
    parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
    parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
    parser.add_argument("--crawled_file_path", type=str, default="crawled_results.joblib")
    parser.add_argument("--combined_file_path", type=str, default="combined_results.joblib")
    parser.add_argument("--regularized_file_path", type=str, default="regularized_results.joblib")
    parser.add_argument("--postprocessed_file_path", type=str, default="postprocessed_results.csv")
    parser.add_argument("--formatted_results_path", type=str, default="formatted_results.csv")
    parser.add_argument("--filtered_file_path", type=str, default="filtered_results.csv")
    # NOTE(review): type=list splits a command-line string into individual
    # characters — only the defaults below are usable from the CLI. Consider
    # nargs='+' if these must be overridable (interface kept unchanged here).
    parser.add_argument("--classes", type=list, default=['小吃店','日式料理(含居酒屋,串燒)','火(鍋/爐)','東南亞料理(不含日韓)','海鮮熱炒','特色餐廳(含雞、鵝、牛、羊肉)','釣蝦場','傳統餐廳','燒烤','韓式料理(含火鍋,烤肉)','PUB(Live Band)','PUB(一般,含Lounge)','PUB(電音\舞場)','五星級飯店','自助KTV(含連鎖,庭園自助)','西餐廳(含美式,義式,墨式)','咖啡廳(泡沫紅茶)','飯店(星級/旅館,不含五星級)','運動休閒館(含球類練習場,飛鏢等)','西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)','西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)','早餐'])
    parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
    parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
    parser.add_argument("--filter_provider", type=str, default='google', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--filter_model", type=str, default='gemini-1.5-flash', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--extraction_provider", type=str, default='openai', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--extraction_model", type=str, default='gpt-3.5-turbo-0125', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--regularization_provider", type=str, default='google', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--regularization_model", type=str, default='gemini-1.5-flash', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--serp_provider", type=str, default='serp', choices=['serp', 'serper'])
    parser.add_argument("--n_processes", type=int, default=4)
    args = parser.parse_args()

    # BUG FIX: the original did yaml.safe_load(open(...).read()) and never
    # closed the config file; use a context manager instead.
    with open(args.config, "r") as config_file:
        config = yaml.safe_load(config_file)
    category2supercategory = config['category2supercategory']
    supercategory2category = reverse_category2supercategory(category2supercategory)

    if args.task == 'new':
        main(args)
    elif args.task == 'continue':
        continue_missing(args)
    else:
        raise Exception(f"Task {args.task} not implemented")