Spaces:
Runtime error
Runtime error
shaocongma
committed on
Commit
·
c42190b
1
Parent(s):
acf8a73
Bug fix: error when abstract is None.
Browse files- api_wrapper.py +13 -5
- auto_backgrounds.py +22 -10
- utils/references.py +6 -2
- worker.py +172 -0
api_wrapper.py
CHANGED
|
@@ -12,18 +12,26 @@ todo:
|
|
| 12 |
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
| 13 |
Change Task status from Running to Failed.
|
| 14 |
'''
|
|
|
|
| 15 |
|
| 16 |
from auto_backgrounds import generate_draft
|
| 17 |
-
import json
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
-
GENERATOR_MAPPING = {"draft": generate_draft}
|
|
|
|
| 21 |
|
| 22 |
def generator_wrapper(path_to_config_json):
|
| 23 |
# Read configuration file and call corresponding function
|
| 24 |
with open(path_to_config_json, "r", encoding='utf-8') as f:
|
| 25 |
config = json.load(f)
|
| 26 |
-
|
| 27 |
-
generator = GENERATOR_MAPPING.get(config["generator"])
|
|
|
|
| 28 |
if generator is None:
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
| 13 |
Change Task status from Running to Failed.
|
| 14 |
'''
|
| 15 |
+
import os.path
|
| 16 |
|
| 17 |
from auto_backgrounds import generate_draft
|
| 18 |
+
import json, time
|
| 19 |
+
from utils.file_operations import make_archive
|
| 20 |
|
| 21 |
|
| 22 |
+
# GENERATOR_MAPPING = {"draft": generate_draft}
|
| 23 |
+
GENERATOR_MAPPING = {"draft": None}
|
| 24 |
|
| 25 |
def generator_wrapper(path_to_config_json):
|
| 26 |
# Read configuration file and call corresponding function
|
| 27 |
with open(path_to_config_json, "r", encoding='utf-8') as f:
|
| 28 |
config = json.load(f)
|
| 29 |
+
print("Configuration:", config)
|
| 30 |
+
# generator = GENERATOR_MAPPING.get(config["generator"])
|
| 31 |
+
generator = None
|
| 32 |
if generator is None:
|
| 33 |
+
# generate a fake ZIP file and upload
|
| 34 |
+
time.sleep(150)
|
| 35 |
+
zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
|
| 36 |
+
return make_archive(path_to_config_json, zip_path)
|
| 37 |
+
|
auto_backgrounds.py
CHANGED
|
@@ -3,7 +3,6 @@ from utils.references import References
|
|
| 3 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
| 4 |
from utils.tex_processing import create_copies
|
| 5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
| 6 |
-
from references_generator import generate_top_k_references
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
|
|
@@ -26,12 +25,14 @@ def log_usage(usage, generating_target, print_out=True):
|
|
| 26 |
TOTAL_PROMPTS_TOKENS += prompts_tokens
|
| 27 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
| 28 |
|
| 29 |
-
message = f"For generating {generating_target}, {total_tokens} tokens have been used
|
|
|
|
| 30 |
f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
|
| 31 |
if print_out:
|
| 32 |
print(message)
|
| 33 |
logging.info(message)
|
| 34 |
|
|
|
|
| 35 |
def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
| 36 |
max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
|
| 37 |
"""
|
|
@@ -44,9 +45,12 @@ def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
|
| 44 |
title (str): The title of the paper.
|
| 45 |
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
|
| 46 |
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
|
| 47 |
-
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
|
| 51 |
|
| 52 |
Returns:
|
|
@@ -111,21 +115,29 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
|
|
| 111 |
def generate_draft(title, description="", template="ICLR2022",
|
| 112 |
tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
|
| 113 |
# pre-processing `sections` parameter;
|
|
|
|
|
|
|
| 114 |
print("================PRE-PROCESSING================")
|
| 115 |
if sections is None:
|
| 116 |
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
|
| 117 |
|
| 118 |
# todo: add more parameters; select which section to generate; select maximum refs.
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
# main components
|
|
|
|
| 122 |
for section in sections:
|
| 123 |
-
print(f"
|
| 124 |
max_attempts = 4
|
| 125 |
attempts_count = 0
|
| 126 |
while attempts_count < max_attempts:
|
| 127 |
try:
|
| 128 |
usage = section_generation(paper, section, destination_folder, model=model)
|
|
|
|
| 129 |
log_usage(usage, section)
|
| 130 |
break
|
| 131 |
except Exception as e:
|
|
@@ -153,7 +165,7 @@ if __name__ == "__main__":
|
|
| 153 |
import openai
|
| 154 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
output = generate_draft(
|
| 159 |
print(output)
|
|
|
|
| 3 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
| 4 |
from utils.tex_processing import create_copies
|
| 5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
|
|
|
| 6 |
import logging
|
| 7 |
import time
|
| 8 |
|
|
|
|
| 25 |
TOTAL_PROMPTS_TOKENS += prompts_tokens
|
| 26 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
| 27 |
|
| 28 |
+
message = f"For generating {generating_target}, {total_tokens} tokens have been used " \
|
| 29 |
+
f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
|
| 30 |
f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
|
| 31 |
if print_out:
|
| 32 |
print(message)
|
| 33 |
logging.info(message)
|
| 34 |
|
| 35 |
+
|
| 36 |
def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
| 37 |
max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
|
| 38 |
"""
|
|
|
|
| 45 |
title (str): The title of the paper.
|
| 46 |
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
|
| 47 |
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
|
| 48 |
+
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
|
| 49 |
+
for the collected papers. Defaults to False.
|
| 50 |
+
max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
|
| 51 |
+
Defaults to 10.
|
| 52 |
+
max_num_refs (int, optional): The maximum number of references that can be included in the paper.
|
| 53 |
+
Defaults to 50.
|
| 54 |
bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
|
| 55 |
|
| 56 |
Returns:
|
|
|
|
| 115 |
def generate_draft(title, description="", template="ICLR2022",
|
| 116 |
tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
|
| 117 |
# pre-processing `sections` parameter;
|
| 118 |
+
print("================START================")
|
| 119 |
+
print(f"Generating {title}.")
|
| 120 |
print("================PRE-PROCESSING================")
|
| 121 |
if sections is None:
|
| 122 |
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
|
| 123 |
|
| 124 |
# todo: add more parameters; select which section to generate; select maximum refs.
|
| 125 |
+
if model == "gpt-4":
|
| 126 |
+
max_tokens = 4096
|
| 127 |
+
else:
|
| 128 |
+
max_tokens = 2048
|
| 129 |
+
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, max_num_refs, bib_refs, max_tokens=max_tokens)
|
| 130 |
|
| 131 |
# main components
|
| 132 |
+
print(f"================PROCESSING================")
|
| 133 |
for section in sections:
|
| 134 |
+
print(f"Generate {section} part...")
|
| 135 |
max_attempts = 4
|
| 136 |
attempts_count = 0
|
| 137 |
while attempts_count < max_attempts:
|
| 138 |
try:
|
| 139 |
usage = section_generation(paper, section, destination_folder, model=model)
|
| 140 |
+
print(f"{section} part has been generated. ")
|
| 141 |
log_usage(usage, section)
|
| 142 |
break
|
| 143 |
except Exception as e:
|
|
|
|
| 165 |
import openai
|
| 166 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 167 |
|
| 168 |
+
target_title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
|
| 169 |
+
target_description = ""
|
| 170 |
+
output = generate_draft(target_title, target_description, tldr=True, max_kw_refs=10)
|
| 171 |
print(output)
|
utils/references.py
CHANGED
|
@@ -334,8 +334,12 @@ class References:
|
|
| 334 |
prompts = {}
|
| 335 |
tokens = 0
|
| 336 |
for paper in result:
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
if tokens >= max_tokens:
|
| 340 |
break
|
| 341 |
return prompts
|
|
|
|
| 334 |
prompts = {}
|
| 335 |
tokens = 0
|
| 336 |
for paper in result:
|
| 337 |
+
abstract = paper.get("abstract")
|
| 338 |
+
if abstract is not None and isinstance(abstract, str):
|
| 339 |
+
prompts[paper["paper_id"]] = paper["abstract"]
|
| 340 |
+
tokens += tiktoken_len(paper["abstract"])
|
| 341 |
+
else:
|
| 342 |
+
prompts[paper["paper_id"]] = " "
|
| 343 |
if tokens >= max_tokens:
|
| 344 |
break
|
| 345 |
return prompts
|
worker.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
This script is only used for service-side host.
|
| 3 |
+
'''
|
| 4 |
+
import boto3
|
| 5 |
+
import os, time
|
| 6 |
+
from api_wrapper import generator_wrapper
|
| 7 |
+
from sqlalchemy import create_engine, Table, MetaData, update, select
|
| 8 |
+
from sqlalchemy.orm import sessionmaker
|
| 9 |
+
from sqlalchemy import inspect
|
| 10 |
+
|
| 11 |
+
# Service configuration, read from the environment once at import time.
# A variable that is not set yields None.
QUEUE_URL = os.environ.get('QUEUE_URL')
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')
BUCKET_NAME = os.environ.get('BUCKET_NAME')
DB_STRING = os.environ.get('DATABASE_STRING')

# Database engine shared by the whole module, plus a session factory bound to it.
ENGINE = create_engine(DB_STRING)
SESSION = sessionmaker(bind=ENGINE)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#######################################################################################################################
|
| 23 |
+
# Amazon SQS Handler
|
| 24 |
+
#######################################################################################################################
|
| 25 |
+
def get_sqs_client():
    """Return a boto3 SQS client for region ``us-east-2``.

    The client is built with the module-level AWS_ACCESS_KEY_ID and
    AWS_SECRET_ACCESS_KEY credentials.
    """
    credentials = {
        "aws_access_key_id": AWS_ACCESS_KEY_ID,
        "aws_secret_access_key": AWS_SECRET_ACCESS_KEY,
    }
    return boto3.client('sqs', region_name="us-east-2", **credentials)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def receive_message():
    """Poll the queue at QUEUE_URL once.

    Returns:
        tuple: ``(response, receipt_handle)``. ``response`` is the raw value
        returned by the client's ``receive_message`` call. ``receipt_handle``
        is the ``ReceiptHandle`` of the first entry under ``Messages``, or
        ``None`` when the response carries no ``Messages`` entry.
    """
    response = get_sqs_client().receive_message(QueueUrl=QUEUE_URL)
    messages = response.get('Messages')
    if messages is None:
        return response, None
    return response, messages[0]['ReceiptHandle']
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def delete_message(receipt_handle):
    """Delete the message identified by `receipt_handle` from the queue at QUEUE_URL.

    Returns the raw response of the client's ``delete_message`` call.
    """
    client = get_sqs_client()
    return client.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=receipt_handle)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
#######################################################################################################################
|
| 49 |
+
# AWS S3 Handler
|
| 50 |
+
#######################################################################################################################
|
| 51 |
+
def get_s3_client():
    """Create an S3 resource and the Bucket object for BUCKET_NAME.

    The credentials are read from the AWS_ACCESS_KEY_ID and
    AWS_SECRET_ACCESS_KEY environment variables on every call.

    Returns:
        tuple: ``(s3, bucket)`` - the S3 resource and its ``Bucket(BUCKET_NAME)``.
    """
    aws_session = boto3.Session(
        aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
    )
    resource = aws_session.resource('s3')
    return resource, resource.Bucket(BUCKET_NAME)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def upload_file(file_name, target_name=None):
    """Upload the local file `file_name` to the bucket BUCKET_NAME.

    Args:
        file_name: Path of the local file to upload.
        target_name: Object key to store the file under. When None, `file_name`
            itself is used as the key.
    """
    object_key = file_name if target_name is None else target_name
    resource, _ = get_s3_client()
    resource.meta.client.upload_file(Filename=file_name, Bucket=BUCKET_NAME, Key=object_key)
    print(f"The file {file_name} has been uploaded!")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def download_file(file_name):
    """Download the object stored under key `file_name` from the bucket BUCKET_NAME.

    The object is saved under the base name of the key (any directory part of
    the key is dropped), given as a relative path.
    """
    local_name = os.path.basename(file_name)
    resource, _ = get_s3_client()
    resource.meta.client.download_file(Bucket=BUCKET_NAME, Key=file_name, Filename=local_name)
    print(f"The file {file_name} has been downloaded!")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
#######################################################################################################################
|
| 84 |
+
# AWS SQL Handler
|
| 85 |
+
#######################################################################################################################
|
| 86 |
+
def modify_status(task_id, new_status):
    """Set the ``status`` column of the matching row in the ``task`` table.

    The row is looked up by its ``task_id`` column. When no such row exists,
    nothing is written.

    Args:
        task_id: Value matched against the ``task_id`` column.
        new_status: Value stored in ``status``. The pipeline in this module
            uses 0 - pending, 1 - running, 2 - completed, 3 - failed.
    """
    # Reflect the table definition from the database.
    task_table = Table('task', MetaData(), autoload_with=ENGINE)
    matches_task = task_table.c.task_id == task_id

    # All work goes through a single Core connection; the ORM session that
    # used to be opened and closed here was never used, so it is gone.
    with ENGINE.connect() as connection:
        task_row = connection.execute(select(task_table).where(matches_task)).fetchone()
        if task_row is None:
            # Unknown task: nothing to update.
            return
        connection.execute(
            update(task_table).where(matches_task).values(status=new_status)
        )
        connection.commit()
|
| 112 |
+
|
| 113 |
+
#######################################################################################################################
|
| 114 |
+
# Pipeline
|
| 115 |
+
#######################################################################################################################
|
| 116 |
+
def pipeline(message_count=0, query_interval=10):
    """Poll SQS once and, if a task message is available, run that task.

    Steps for a received message: take the message body as the S3 key of a
    configuration file, download it, mark the task as running, delete the
    message, run `generator_wrapper` on the local configuration file, upload
    the resulting archive next to the configuration on S3 and mark the task
    as completed.

    If generation or upload raises, or the generator returns nothing, the
    task is marked as failed (status 3) and the exception is re-raised, so
    the task is not left in the running state.

    Args:
        message_count: Counter that is only printed in the banner.
        query_interval: Seconds to sleep when the queue returned no message.
    """
    # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed

    # Query a message from SQS
    msg, handle = receive_message()
    if handle is None:
        print("No message in SQS. ")
        time.sleep(query_interval)
        return

    print("===============================================================================================")
    print(f"MESSAGE COUNT: {message_count}")
    print("===============================================================================================")
    config_s3_path = msg['Messages'][0]['Body']
    config_s3_dir = os.path.dirname(config_s3_path)
    config_local_path = os.path.basename(config_s3_path)
    # The task id is the configuration file's name without its extension.
    task_id, _ = os.path.splitext(config_local_path)

    print("Initializing ...")
    print("Configuration file on S3: ", config_s3_path)
    print("Configuration file on S3 (Directory): ", config_s3_dir)
    print("Local file path: ", config_local_path)
    print("Task id: ", task_id)

    print(f"Success in receiving message: {msg}")
    print(f"Configuration file path: {config_s3_path}")

    # Process the downloaded configuration file
    download_file(config_s3_path)
    modify_status(task_id, 1)
    delete_message(handle)
    print("Success in the initialization. Message deleted.")

    print("Running ...")
    try:
        zip_path = generator_wrapper(config_local_path)
        if not zip_path:
            raise RuntimeError(f"The generator returned no output for task {task_id}.")
        # Upload the generated file to S3; S3 keys always use forward slashes.
        upload_to = os.path.join(config_s3_dir, zip_path).replace("\\", "/")

        print("Local file path (ZIP): ", zip_path)
        print("Upload to S3: ", upload_to)
        upload_file(zip_path, upload_to)
    except Exception:
        # The message is already deleted, so record the failure before
        # propagating; otherwise the task would stay "running" forever.
        modify_status(task_id, 3)
        print(f"Task {task_id} failed. Status changed to failed.")
        raise

    modify_status(task_id, 2)
    print("Success in generating the paper.")

    # Complete.
    print("Task completed.")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def initialize_everything():
    """Placeholder for resetting the service state; currently does nothing.

    Planned but not implemented steps: clear S3, clear SQS.
    """
    return None
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
if __name__ == "__main__":
|
| 172 |
+
pipeline()
|