Spaces:
Runtime error
Runtime error
Commit
·
948e91c
1
Parent(s):
60274d1
add missing data workflow
Browse files
model.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
-
import json
|
| 3 |
import argparse
|
|
|
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import anthropic
|
|
@@ -10,7 +10,7 @@ from utils import parse_json_garbage
|
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
-
def llm( provider, model, system_prompt, user_content):
|
| 14 |
"""Invoke LLM service
|
| 15 |
Argument
|
| 16 |
--------
|
|
@@ -26,6 +26,9 @@ def llm( provider, model, system_prompt, user_content):
|
|
| 26 |
------
|
| 27 |
response: str
|
| 28 |
"""
|
|
|
|
|
|
|
|
|
|
| 29 |
if provider=='openai':
|
| 30 |
client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
|
| 31 |
chat_completion = client.chat.completions.create(
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import argparse
|
| 3 |
+
import time
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import anthropic
|
|
|
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
+
def llm( provider, model, system_prompt, user_content, delay:int = 10):
|
| 14 |
"""Invoke LLM service
|
| 15 |
Argument
|
| 16 |
--------
|
|
|
|
| 26 |
------
|
| 27 |
response: str
|
| 28 |
"""
|
| 29 |
+
if delay:
|
| 30 |
+
time.sleep(delay)
|
| 31 |
+
|
| 32 |
if provider=='openai':
|
| 33 |
client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
|
| 34 |
chat_completion = client.chat.completions.create(
|
sheet.py
CHANGED
|
@@ -165,7 +165,7 @@ def classify_results(
|
|
| 165 |
label = parse_json_garbage(pred_cls)['category']
|
| 166 |
labels.append(label)
|
| 167 |
except Exception as e:
|
| 168 |
-
print(f"# CLASSIFICATION error -> evidence: {
|
| 169 |
labels.append("")
|
| 170 |
empty_indices.append(idx)
|
| 171 |
|
|
@@ -488,10 +488,58 @@ def split_dataframe( df: pd.DataFrame, n_processes: int = 4) -> list:
|
|
| 488 |
n_per_process = math.ceil(n / n_processes)
|
| 489 |
return [ df.iloc[i:i+n_per_process] for i in range(0, n, n_per_process)]
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
def main(args):
|
| 492 |
"""
|
| 493 |
Argument
|
| 494 |
args: argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
"""
|
| 496 |
crawled_file_path = os.path.join( args.output_dir, args.crawled_file_path)
|
| 497 |
extracted_file_path = os.path.join( args.output_dir, args.extracted_file_path)
|
|
@@ -501,11 +549,11 @@ def main(args):
|
|
| 501 |
formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
|
| 502 |
|
| 503 |
## 讀取資料名單 ##
|
| 504 |
-
data = get_leads(args.data_path)
|
| 505 |
|
| 506 |
## 進行爬蟲與分析 ##
|
| 507 |
crawled_results = crawl_results_mp( data, crawled_file_path, n_processes=args.n_processes)
|
| 508 |
-
crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
|
| 509 |
|
| 510 |
## 方法 1: 擷取關鍵資訊與分類 ##
|
| 511 |
extracted_results = extract_results_mp(
|
|
@@ -596,6 +644,7 @@ if __name__=='__main__':
|
|
| 596 |
|
| 597 |
parser = argparse.ArgumentParser()
|
| 598 |
parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
|
|
|
|
| 599 |
parser.add_argument("--output_dir", type=str, help='output directory')
|
| 600 |
parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
|
| 601 |
parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
|
|
@@ -606,9 +655,16 @@ if __name__=='__main__':
|
|
| 606 |
parser.add_argument("--classes", type=list, default=['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', '西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)', '西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)', '早餐'])
|
| 607 |
parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
|
| 608 |
parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
|
| 609 |
-
parser.add_argument("--provider", type=str, default='
|
| 610 |
-
parser.add_argument("--model", type=str, default='
|
| 611 |
parser.add_argument("--n_processes", type=int, default=4)
|
| 612 |
args = parser.parse_args()
|
| 613 |
|
| 614 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
label = parse_json_garbage(pred_cls)['category']
|
| 166 |
labels.append(label)
|
| 167 |
except Exception as e:
|
| 168 |
+
print(f"# CLASSIFICATION error: e -> {e}, user_content -> {user_content}, evidence: {evidence}")
|
| 169 |
labels.append("")
|
| 170 |
empty_indices.append(idx)
|
| 171 |
|
|
|
|
| 488 |
n_per_process = math.ceil(n / n_processes)
|
| 489 |
return [ df.iloc[i:i+n_per_process] for i in range(0, n, n_per_process)]
|
| 490 |
|
| 491 |
+
|
| 492 |
+
def continue_missing(args):
|
| 493 |
+
"""
|
| 494 |
+
"""
|
| 495 |
+
data = get_leads(args.data_path)
|
| 496 |
+
n_data = data.shape[0]
|
| 497 |
+
|
| 498 |
+
formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
|
| 499 |
+
formatted_results = pd.read_csv(formatted_results_path)
|
| 500 |
+
missing_indices = []
|
| 501 |
+
for i in range(n_data):
|
| 502 |
+
if i not in formatted_results['index'].unique():
|
| 503 |
+
print(f"{i} is not found")
|
| 504 |
+
missing_indices.append(i)
|
| 505 |
+
|
| 506 |
+
crawled_results_path = os.path.join( args.output_dir, args.crawled_file_path)
|
| 507 |
+
crawled_results = joblib.load( open( crawled_results_path, "rb"))
|
| 508 |
+
crawled_results = crawled_results['crawled_results'].query( f"index in {missing_indices}")
|
| 509 |
+
print( crawled_results)
|
| 510 |
+
|
| 511 |
+
er = extract_results( crawled_results, classes = args.classes, provider = args.provider, model = args.model)
|
| 512 |
+
er = er['extracted_results']
|
| 513 |
+
print(er['category'])
|
| 514 |
+
|
| 515 |
+
postprossed_results = postprocess_result(
|
| 516 |
+
er,
|
| 517 |
+
"/tmp/postprocessed_results.joblib",
|
| 518 |
+
category2supercategory
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
out_formatted_results = format_output(
|
| 522 |
+
postprossed_results,
|
| 523 |
+
input_column = 'evidence',
|
| 524 |
+
output_column = 'formatted_evidence',
|
| 525 |
+
format_func = format_evidence
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
out_formatted_results.to_csv( "/tmp/formatted_results.missing.csv", index=False)
|
| 529 |
+
formatted_results = pd.concat([formatted_results, out_formatted_results], ignore_index=True)
|
| 530 |
+
formatted_results.sort_values(by='index', ascending=True, inplace=True)
|
| 531 |
+
formatted_results.to_csv( "/tmp/formatted_results.csv", index=False)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
def main(args):
|
| 535 |
"""
|
| 536 |
Argument
|
| 537 |
args: argparse
|
| 538 |
+
Note
|
| 539 |
+
200 records
|
| 540 |
+
crawl: 585.3285548686981
|
| 541 |
+
extract: 2791.631685256958(delay = 10)
|
| 542 |
+
classify: 2374.4915606975555(delay = 10)
|
| 543 |
"""
|
| 544 |
crawled_file_path = os.path.join( args.output_dir, args.crawled_file_path)
|
| 545 |
extracted_file_path = os.path.join( args.output_dir, args.extracted_file_path)
|
|
|
|
| 549 |
formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
|
| 550 |
|
| 551 |
## 讀取資料名單 ##
|
| 552 |
+
data = get_leads(args.data_path)
|
| 553 |
|
| 554 |
## 進行爬蟲與分析 ##
|
| 555 |
crawled_results = crawl_results_mp( data, crawled_file_path, n_processes=args.n_processes)
|
| 556 |
+
# crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
|
| 557 |
|
| 558 |
## 方法 1: 擷取關鍵資訊與分類 ##
|
| 559 |
extracted_results = extract_results_mp(
|
|
|
|
| 644 |
|
| 645 |
parser = argparse.ArgumentParser()
|
| 646 |
parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
|
| 647 |
+
parser.add_argument("--task", type=str, default="new", choices = ["new", "continue"], help="new or continue")
|
| 648 |
parser.add_argument("--output_dir", type=str, help='output directory')
|
| 649 |
parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
|
| 650 |
parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
|
|
|
|
| 655 |
parser.add_argument("--classes", type=list, default=['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', '西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)', '西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)', '早餐'])
|
| 656 |
parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
|
| 657 |
parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
|
| 658 |
+
parser.add_argument("--provider", type=str, default='openai', choices=['openai', 'anthropic'])
|
| 659 |
+
parser.add_argument("--model", type=str, default='gpt-4-0125-preview', choices=['claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'])
|
| 660 |
parser.add_argument("--n_processes", type=int, default=4)
|
| 661 |
args = parser.parse_args()
|
| 662 |
|
| 663 |
+
if args.task == 'new':
|
| 664 |
+
main(args)
|
| 665 |
+
elif args.task == 'continue':
|
| 666 |
+
continue_missing(args)
|
| 667 |
+
else:
|
| 668 |
+
raise Exception(f"Task {args.task} not implemented")
|
| 669 |
+
|
| 670 |
+
|
utils.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
|
| 3 |
def parse_json_garbage(s):
|
| 4 |
s = s[next(idx for idx, c in enumerate(s) if c in "{["):]
|
|
|
|
|
|
|
|
|
|
| 5 |
try:
|
| 6 |
-
return json.loads(s)
|
| 7 |
except json.JSONDecodeError as e:
|
| 8 |
-
return json.loads(s
|
| 9 |
|
|
|
|
| 1 |
+
import re
|
| 2 |
import json
|
| 3 |
|
| 4 |
def parse_json_garbage(s):
|
| 5 |
s = s[next(idx for idx, c in enumerate(s) if c in "{["):]
|
| 6 |
+
print(s)
|
| 7 |
+
s = s[:next(idx for idx, c in enumerate(s) if c in "}]")+1]
|
| 8 |
+
print(s)
|
| 9 |
try:
|
| 10 |
+
return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
|
| 11 |
except json.JSONDecodeError as e:
|
| 12 |
+
return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
|
| 13 |
|