# Pipeline script: prepare, submit, poll, and post-process OpenAI Batch API
# jobs for business-category extraction and regularization.
import os
import sys
import json
import glob
import time
import yaml
import joblib
import argparse
import jinja2
import anthropic
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from openai import OpenAI
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

# Project-local helpers
from data import get_leads
from utils import parse_json_garbage, compose_query

# Enable `.progress_apply` on pandas objects (used in postprocess_result).
tqdm.pandas()

# Replace loguru's default handler with an INFO-level stderr sink.
# ValueError means handler id 0 was already removed (e.g. on re-import) -- safe to ignore.
try:
    logger.remove(0)
    logger.add(sys.stderr, level="INFO")
except ValueError:
    pass

# Load secrets (e.g. ORGANIZATION_ID for the OpenAI client) from a .env file.
load_dotenv()
def prepare_batch( crawled_result_path: str, config: dict, output_path: str, topn: int = None):
    """
    Build an OpenAI Batch API input file (JSONL) from crawled results.

    Argument
    --------
    crawled_result_path: str
        Path to the crawled result file (result from the crawl task);
        a joblib dump containing a `crawled_results` DataFrame.
    config: dict
        Configuration for the batch job. Must provide `extraction_prompt`,
        `user_content` (jinja2 templates), `classes`, `traits`, `model`,
        `max_tokens` and `temperature`.
    output_path: str
        Path to the output JSONL file; one request object per line, e.g.
        {"custom_id": "...", "method": "POST", "url": "/v1/chat/completions",
         "body": {"model": ..., "messages": [...], "max_tokens": ...,
                  "temperature": ..., "response_format": {"type": "json_object"}}}
    topn: int
        If given, only the first `topn` crawled records are processed.
    """
    assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}"
    with open(crawled_result_path, "rb") as f:  # close the handle instead of leaking it
        crawled_results = joblib.load(f)['crawled_results']
    if topn:
        crawled_results = crawled_results.head(topn)

    jenv = jinja2.Environment()
    # Keep separate names for the two templates instead of reusing one variable.
    system_template = jenv.from_string(config['extraction_prompt'])
    system_prompt = system_template.render( classes = config['classes'], traits = config['traits'])
    user_template = jenv.from_string(config['user_content'])

    items = []
    for d in tqdm(crawled_results.itertuples(), total=len(crawled_results)):
        # Evidence shown to the LLM: Google-Maps snippets plus web-search snippets.
        evidence = d.googlemap_results + "\n" + d.search_results
        query = compose_query( d.address, d.business_name, use_exclude=False)
        user_content = user_template.render( query = query, search_results = evidence)
        item = {
            "custom_id": str(d.business_id),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": config['model'],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                "max_tokens": config['max_tokens'],
                "temperature": config['temperature'],
                "response_format": {"type": "json_object"},
            }
        }
        items.append( json.dumps(item, ensure_ascii=False))

    with open(output_path, "w") as f:
        for item in items:
            f.write(item + "\n")
def prepare_regularization( extracted_result_path: str, config: dict, output_path: str, topn: int = None):
    """
    Build an OpenAI Batch API input file (JSONL) that asks the model to
    regularize the categories produced by the extraction step.

    Argument
    --------
    extracted_result_path: str
        Path to the extracted result file (result from the extraction task);
        a joblib dump containing an `extracted_results` DataFrame.
    config: dict
        Configuration for the batch job. Must provide `regularization_prompt`,
        `regularization_user_content` (jinja2 templates), `model`,
        `max_tokens` and `temperature`.
    output_path: str
        Path to the output JSONL file; one request object per line.
    topn: int
        If given, only the first `topn` extracted records are processed.
    """
    assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}"
    with open(extracted_result_path, "rb") as f:  # close the handle instead of leaking it
        extracted_results = joblib.load(f)['extracted_results']
    if topn:
        extracted_results = extracted_results.head(topn)

    jenv = jinja2.Environment()
    system_template = jenv.from_string(config['regularization_prompt'])
    system_prompt = system_template.render()
    user_template = jenv.from_string(config['regularization_user_content'])

    items = []
    for d in tqdm(extracted_results.itertuples(), total=len(extracted_results)):
        category = d.category
        # Treat missing/NaN or empty categories as the empty string.
        if pd.isna(category) or len(category)==0:
            category = ""
        user_content = user_template.render( category = category)
        item = {
            "custom_id": str(d.business_id),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": config['model'],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                "max_tokens": config['max_tokens'],
                "temperature": config['temperature'],
                "response_format": {"type": "json_object"},
            }
        }
        items.append( json.dumps(item, ensure_ascii=False))

    with open(output_path, "w") as f:
        for item in items:
            f.write(item + "\n")
def _dump_batch_output( output_resp, jsonl_path: str) -> list:
    """Write the non-empty lines of a batch output response to `jsonl_path`
    and return them as a list."""
    llm_results = []
    with open(jsonl_path, "w") as f:
        for line in output_resp.content.decode('utf-8').split("\n"):
            line = line.strip()
            if len(line)==0:
                # Skip blank lines instead of truncating the rest of the output.
                continue
            llm_results.append(line)
            f.write(f"{line}\n")
    return llm_results


def run_batch( input_path: str, job_path: str, jsonl_path: str):
    """
    Upload a prepared batch input file, create a batch job, poll it until it
    reaches a terminal state, and save the raw JSONL results.

    Argument
    --------
    input_path: str
        Path to the prepared batch input file (result from prepare_batch)
    job_path: str
        Path to the job file (joblib dump of the batch-creation response)
    jsonl_path: str
        Path to the output file (raw JSONL results from the batch)
    """
    assert os.path.exists(input_path), f"File not found: {input_path}"
    st = time.time()
    client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))

    with open( input_path, "rb") as f:
        batch_input_file = client.files.create( file=f, purpose="batch")
    batch_input_file_id = batch_input_file.id
    logger.info(f"batch_input_file_id -> {batch_input_file_id}")

    batch_resp = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "batch job"
        }
    )
    logger.info(f"batch resp -> {batch_resp}")

    # Persist the job handle so the batch can be looked up later; fall back to
    # a local path if the requested one is unwritable (best-effort behavior).
    try:
        with open( job_path, "wb") as f:
            joblib.dump(batch_resp, f)
    except Exception as e:
        logger.error(f"Error -> {e}")
        with open("./job.joblib", "wb") as f:
            joblib.dump(batch_resp, f)

    # Poll every 10 seconds until the batch reaches a terminal status.
    is_ready = False
    while True:
        batch_resp = client.batches.retrieve(batch_resp.id)
        if batch_resp.status == 'validating':
            logger.info("the input file is being validated before the batch can begin")
        elif batch_resp.status == 'failed':
            logger.info("the input file has failed the validation process")
            break
        elif batch_resp.status == 'in_progress':
            logger.info("the input file was successfully validated and the batch is currently being run")
        elif batch_resp.status == 'finalizing':
            logger.info("the batch has completed and the results are being prepared")
        elif batch_resp.status == 'completed':
            logger.info("the batch has been completed and the results are ready")
            is_ready = True
            break
        elif batch_resp.status == 'expired':
            logger.info("the batch was not able to be completed within the 24-hour time window")
            break
        elif batch_resp.status == 'cancelling':
            logger.info("the batch is being cancelled (may take up to 10 minutes)")
        elif batch_resp.status == 'cancelled':
            logger.info("the batch was cancelled")
            break
        else:
            # Bug fix: `raise logger.error(...)` raised None (a TypeError);
            # log the problem and raise a real exception instead.
            logger.error(f"Invalid status -> {batch_resp.status}")
            raise ValueError(f"Invalid batch status: {batch_resp.status}")
        time.sleep(10)

    if is_ready:
        output_resp = client.files.content(batch_resp.output_file_id)
        try:
            _dump_batch_output( output_resp, jsonl_path)
        except Exception as e:
            # Best-effort fallback to a local path if jsonl_path is unwritable.
            logger.error(f"Error -> {e}")
            _dump_batch_output( output_resp, "./output.jsonl")

    print( f"Time elapsed: {time.time()-st:.2f} seconds")
def batch2extract( jsonl_path: str, crawled_result_path: str, extracted_result_path: str):
    """
    Parse raw batch output (JSONL) into an extracted-results DataFrame, merge
    it with basic info from the crawled results, and save it with joblib.

    Argument
    --------
    jsonl_path: str
        Path to the batch output file
    crawled_result_path: str
        Path to the crawled result file (result from the crawl task)
    extracted_result_path: str
        Path to the extracted result file (joblib dump with keys
        `extracted_results` and `empty_indices`)
    """
    assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}"
    assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}"
    with open(crawled_result_path, "rb") as f:
        crawled_results = joblib.load(f)

    # Read the raw batch output, skipping blank lines. (The original `break`
    # on a blank line would silently drop every result after it.)
    llm_results = []
    with open(jsonl_path, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                llm_results.append(line)

    # Parse each response body; failures are logged and their indices recorded.
    extracted_results = []
    empty_indices = []
    for i, llm_result in enumerate(llm_results):
        try:
            llm_result = json.loads(llm_result)
            business_id = llm_result['custom_id']
            content = llm_result['response']['body']['choices'][0]['message']['content']
            parsed = parse_json_garbage(content)
            parsed['business_id'] = business_id
            extracted_results.append(parsed)
        except Exception as e:
            logger.error(f"Error -> {e}, llm_result -> {llm_result}")
            empty_indices.append(i)
    extracted_results = pd.DataFrame(extracted_results)

    # Collect identifying info and evidence for each crawled record.
    basic_info = []
    for d in tqdm(crawled_results['crawled_results'].itertuples()):
        basic_info.append( {
            # NOTE(review): `d.index` assumes the DataFrame has a literal
            # "index" column; on a plain itertuples row this would be the
            # tuple.index method -- confirm upstream schema.
            "index": d.index,
            "business_id": d.business_id,
            "business_name": d.business_name,
            "evidence": d.googlemap_results + "\n" + d.search_results,
        } )
    basic_info = pd.DataFrame(basic_info)

    # Inner-join on business_id (as string) so only parsed records survive.
    extracted_results = basic_info.astype({"business_id": str}).merge(extracted_results, on="business_id", how="inner")
    print( f"{ extracted_results.shape[0]} records merged.")

    extracted_results = {"extracted_results": extracted_results, "empty_indices": empty_indices}
    with open(extracted_result_path, "wb") as f:
        joblib.dump(extracted_results, f)
def batch2reg( jsonl_path: str, extracted_result_path: str, regularized_result_path: str):
    """
    Parse raw batch output (JSONL) into a regularized-results DataFrame, merge
    it with basic info from the extracted results, and save it with joblib.

    Argument
    --------
    jsonl_path: str
        Path to the batch output file
    extracted_result_path: str
        Path to the extracted result file
    regularized_result_path: str
        Path to the regularization result file (joblib dump with keys
        `regularized_results` and `empty_indices`)
    """
    assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}"
    assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}"
    with open(extracted_result_path, "rb") as f:
        extracted_results = joblib.load(f)['extracted_results']

    # Read the raw batch output, skipping blank lines. (The original `break`
    # on a blank line would silently drop every result after it.)
    llm_results = []
    with open(jsonl_path, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                llm_results.append(line)

    # Parse each response body; failures are logged and their indices recorded.
    regularized_results, empty_indices = [], []
    for i, llm_result in enumerate(llm_results):
        try:
            llm_result = json.loads(llm_result)
            business_id = llm_result['custom_id']
            content = llm_result['response']['body']['choices'][0]['message']['content']
            parsed = parse_json_garbage(content)
            parsed['business_id'] = business_id
            regularized_results.append(parsed)
        except Exception as e:
            logger.error(f"Error -> {e}, llm_result -> {llm_result}")
            empty_indices.append(i)
    regularized_results = pd.DataFrame(regularized_results)

    # Collect identifying info and evidence for each extracted record.
    basic_info = []
    for d in tqdm(extracted_results.itertuples()):
        basic_info.append( {
            # NOTE(review): `d.index` assumes the DataFrame has a literal
            # "index" column -- confirm upstream schema.
            "index": d.index,
            "business_id": d.business_id,
            "business_name": d.business_name,
            "evidence": d.evidence,
        } )
    basic_info = pd.DataFrame(basic_info)

    # Inner-join on business_id (as string) so only parsed records survive.
    regularized_results = basic_info.astype({"business_id": str}).merge(regularized_results, on="business_id", how="inner")
    print( f"{ regularized_results.shape[0]} records merged.")

    regularized_results = {"regularized_results": regularized_results, "empty_indices": empty_indices}
    with open(regularized_result_path, "wb") as f:
        joblib.dump(regularized_results, f)
def postprocess_result( config: dict, regularized_result_path: str, postprocessed_result_path, category_hierarchy: dict, column_name: str = 'category') -> pd.DataFrame:
    """
    Sanitize categories against a known hierarchy, attach supercategories,
    and write the result to CSV.

    Argument
    --------
    config: dict
        Configuration (kept for interface compatibility; not used here).
    regularized_result_path: str
        Path to the regularized result file (joblib dump with key
        `regularized_results`).
    postprocessed_result_path: str
        Path where the postprocessed CSV is written.
    category_hierarchy: dict
        Mapping category -> supercategory; categories not in this mapping
        are blanked out.
    column_name: str
        Name of the column holding the category to postprocess.

    Return
    ------
    pd.DataFrame
        The postprocessed results.
    """
    assert os.path.exists(regularized_result_path), f"File not found: {regularized_result_path}"
    with open(regularized_result_path, "rb") as f:
        regularized_results = joblib.load(f)['regularized_results']

    # (Removed dead `if True:` / unreachable `else` branch from the original:
    # the function always recomputes and writes the CSV.)
    postprocessed_results = regularized_results.copy()
    # Blank out categories that are not part of the known hierarchy.
    postprocessed_results.loc[ :, "category"] = postprocessed_results[column_name].progress_apply(lambda x: "" if x not in category_hierarchy else x)
    # Map each category to its supercategory; unknown categories map to "".
    postprocessed_results['supercategory'] = postprocessed_results[column_name].progress_apply(lambda x: category_hierarchy.get(x, ''))
    postprocessed_results.to_csv( postprocessed_result_path, index=False)
    return postprocessed_results
def combine_postprocessed_results( config: dict, input_path: str, postprocessed_result_path: str, reference_path: str, output_path: str):
    """
    Gather every postprocessed CSV matching the given pattern, left-join the
    combined rows onto the reference leads, and write the merged CSV.

    Argument
    --------
    config: dict
        Configuration (not used directly here).
    input_path: str
        Root directory under which postprocessed results live.
    postprocessed_result_path: str
        Sub-path (glob pattern) under input_path containing postprocessed_results.csv.
    reference_path: str
        Path to the reference leads file (loaded via get_leads).
    output_path: str
        Path where the combined CSV is written.
    """
    pattern = str(Path(input_path).joinpath( postprocessed_result_path, "postprocessed_results.csv"))
    logger.info(f"file_pattern -> {pattern}")
    matched = list(glob.glob(pattern))
    assert len(matched)>0, f"File not found: {postprocessed_result_path}"

    # Stack all matching CSVs; keep business_id as a string for the join.
    frames = [pd.read_csv(p, dtype={"business_id": str}) for p in matched]
    combined = pd.concat(frames, axis=0)

    reference = get_leads( reference_path)
    merged = reference.merge( combined, left_on = "統一編號", right_on="business_id", how="left")
    merged.to_csv( output_path, index=False)
if __name__ == "__main__":
    # Command-line entry point: dispatch one pipeline stage per invocation.
    parser = argparse.ArgumentParser()
    parser.add_argument( "-c", "--config", type=str, default='config/config.yml', help="Path to the configuration file")
    parser.add_argument( "-t", "--task", type=str, default='prepare_batch', choices=['prepare_batch', 'prepare_regularization', 'run_batch', 'batch2extract', 'batch2reg', 'postprocess', 'combine'])
    parser.add_argument( "-i", "--input_path", type=str, default='', )
    parser.add_argument( "-o", "--output_path", type=str, default='', )
    parser.add_argument( "-b", "--batch_path", type=str, default='', )
    parser.add_argument( "-j", "--job_path", type=str, default='', )
    parser.add_argument( "-jp", "--jsonl_path", type=str, default='', )
    parser.add_argument( "-crp", "--crawled_result_path", type=str, default='', )
    parser.add_argument( "-erp", "--extracted_result_path", type=str, default='', )
    parser.add_argument( "-rrp", "--regularized_result_path", type=str, default='', )
    parser.add_argument( "-prp", "--postprocessed_result_path", type=str, default='', )
    parser.add_argument( "-rp", "--reference_path", type=str, default='', )
    parser.add_argument( "-topn", "--topn", type=int, default=None )
    args = parser.parse_args()

    assert os.path.exists(args.config), f"File not found: {args.config}"
    # Load the YAML config with a context manager (the original leaked the handle).
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    if args.task == 'prepare_batch':
        prepare_batch( crawled_result_path = args.crawled_result_path, config = config, output_path = args.output_path, topn = args.topn)
    elif args.task == 'run_batch':
        run_batch( input_path = args.input_path, job_path = args.job_path, jsonl_path = args.jsonl_path)
    elif args.task == 'prepare_regularization':
        prepare_regularization( extracted_result_path = args.extracted_result_path, config = config, output_path = args.output_path, topn = args.topn)
    elif args.task == 'batch2extract':
        batch2extract(
            jsonl_path = args.jsonl_path,
            crawled_result_path = args.crawled_result_path,
            extracted_result_path = args.extracted_result_path
        )
    elif args.task == 'batch2reg':
        batch2reg(
            jsonl_path = args.jsonl_path,
            extracted_result_path = args.extracted_result_path,
            regularized_result_path = args.regularized_result_path
        )
    elif args.task == 'postprocess':
        postprocess_result(
            config = config,
            regularized_result_path = args.regularized_result_path,
            postprocessed_result_path = args.postprocessed_result_path,
            category_hierarchy = config['category2supercategory'],
            column_name = 'category'
        )
    elif args.task == 'combine':
        combine_postprocessed_results(
            config,
            args.input_path,
            args.postprocessed_result_path,
            args.reference_path,
            args.output_path
        )
    else:
        # Unreachable in practice: argparse `choices` rejects unknown tasks.
        raise Exception("Invalid task")