youngtsai's picture
if model_name in ["gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]:
df2624f
import json
import os
import unicodedata

import gradio as gr
import pandas as pd
from google.cloud import aiplatform
from google.oauth2.service_account import Credentials
from openai import OpenAI
from vertexai.preview.generative_models import GenerativeModel
# Set up API credentials and clients (OpenAI and Google Vertex AI).
# NOTE(review): os.getenv returns the raw string, so ANY non-empty value
# (including "false" or "0") selects the local-file branch below.
IS_ENV_LOCAL = os.getenv("IS_ENV_LOCAL", False)
if IS_ENV_LOCAL:
    # Local development: secrets are read from local.json; the file is closed
    # right after parsing (the previous bare open() leaked the handle).
    with open("local.json", encoding="utf-8") as local_json_file:
        local_json = json.load(local_json_file)
    openai_api_key = local_json["OPENAI_API_KEY"]
    # Here GBQ_TOKEN is used directly as the parsed service-account mapping.
    GOOGLE_SERVICE_ACCOUNT_INFO = local_json["GBQ_TOKEN"]
    google_service_account_info_dict = GOOGLE_SERVICE_ACCOUNT_INFO
else:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    # In the environment, GBQ_TOKEN holds the service-account info as a JSON string.
    GOOGLE_SERVICE_ACCOUNT_INFO = os.getenv("GBQ_TOKEN")
    if GOOGLE_SERVICE_ACCOUNT_INFO is None:
        # Fail with a clear message instead of json.loads(None)'s TypeError.
        raise RuntimeError("GBQ_TOKEN environment variable is not set.")
    google_service_account_info_dict = json.loads(GOOGLE_SERVICE_ACCOUNT_INFO)

# OpenAI
OPENAI_CLIENT = OpenAI(api_key=openai_api_key)

# Google Vertex AI
GOOGLE_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
GOOGPE_SCOPES = GOOGLE_SCOPES  # misspelled legacy name, kept for backward compatibility
google_creds = Credentials.from_service_account_info(
    google_service_account_info_dict, scopes=GOOGLE_SCOPES
)
aiplatform.init(
    project="junyiacademy",
    # NOTE(review): aiplatform.init documents service_account as a service-account
    # email string, but a dict is passed here — confirm this argument is intended.
    service_account=google_service_account_info_dict,
    credentials=google_creds,
)
GEMINI_MODEL = GenerativeModel("gemini-pro")
def extract_article_from_content(article_text):
    """Pull the article body out of a raw model reply.

    The start markers are tried in list order; for the first one found, the
    text after it is returned with leading newlines skipped, cut at the first
    "\\nThank you" (if any) and right-stripped. If no marker occurs, the reply
    is returned unchanged.

    Args:
        article_text: Raw text returned by the model.

    Returns:
        The extracted article text (possibly empty).
    """
    start_markers = ["新文章:", "New Article:", "Here it is:"]
    end_marker = "\nThank you"
    for start_marker in start_markers:
        marker_index = article_text.find(start_marker)
        if marker_index == -1:
            continue
        start_index = marker_index + len(start_marker)
        # Skip blank lines right after the marker. The bounds check matters: the
        # old unbounded loop raised IndexError when the marker ended the text.
        while start_index < len(article_text) and article_text[start_index] == "\n":
            start_index += 1
        end_index = article_text.find(end_marker, start_index)
        if end_index == -1:
            end_index = len(article_text)
        return article_text[start_index:end_index].rstrip()
    return article_text
def _countable_chars(text):
    """Return *text* without whitespace and Unicode punctuation.

    Any character whose Unicode general category starts with "P" (ASCII and
    full-width CJK punctuation alike, e.g. "，" "。" "「") or that is whitespace
    ("\\n", "\\r", spaces) is dropped; every remaining character counts as one.
    """
    return "".join(
        ch for ch in text
        if not ch.isspace() and not unicodedata.category(ch).startswith("P")
    )


def validate_article(generated_article, lesson_words, base_chars, original_word_count):
    """Check a generated article against the lesson's character constraints.

    Comparison is per character; whitespace and punctuation are ignored in both
    the article and ``lesson_words`` (so separators such as "、" in the lesson
    word list are not treated as required words).

    Args:
        generated_article: Article text produced by the model.
        lesson_words: Characters that must all appear in the article.
        base_chars: Additional characters the article may use.
        original_word_count: Target length; the article must be within ±10% of it.

    Returns:
        dict with keys:
          not_every_new_word_is_used: True if some lesson character is missing.
          word_out_of_range: True if the article uses a character outside
              lesson_words ∪ base_chars.
          word_count_error: True if the length is outside 90%–110% of the target.
          lesson_words_not_in_new_article: missing lesson characters, in input order.
          words_not_in_both: missing lesson characters not in base_chars either.
          additional_words: out-of-range characters in order of first appearance
              (a list, not a set, so the result is deterministic and
              JSON-serialisable for the gr.JSON output).
          count_of_words_in_new_article: article length in countable characters.
    """
    clean_article = _countable_chars(generated_article or "")
    required_chars = list(_countable_chars(lesson_words or ""))
    base_set = set(base_chars or "")
    allowed_chars = set(required_chars) | base_set
    article_chars = set(clean_article)

    lesson_words_not_in_new_article = [ch for ch in required_chars if ch not in article_chars]
    words_not_in_both = [ch for ch in lesson_words_not_in_new_article if ch not in base_set]
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    additional_words = list(dict.fromkeys(ch for ch in clean_article if ch not in allowed_chars))

    count_of_words_in_new_article = len(clean_article)
    word_count_error = not (
        0.9 * original_word_count <= count_of_words_in_new_article <= 1.1 * original_word_count
    )
    return {
        "not_every_new_word_is_used": bool(lesson_words_not_in_new_article),
        "word_out_of_range": bool(additional_words),
        "word_count_error": word_count_error,
        "lesson_words_not_in_new_article": lesson_words_not_in_new_article,
        "words_not_in_both": words_not_in_both,
        "additional_words": additional_words,
        "count_of_words_in_new_article": count_of_words_in_new_article,
    }
def _generate_text(model_name, system_prompt, prompt):
    """Send one prompt to the backend serving *model_name* and return its raw reply text.

    Raises:
        gr.Error: if *model_name* is neither an OpenAI "gpt-*" model nor "gemini-pro".
    """
    # Route every OpenAI chat model by prefix. The previous exact-match list
    # omitted "gpt-4-0125-preview" (the UI's default choice), which left the
    # reply variable unbound and crashed with UnboundLocalError.
    if model_name and model_name.startswith("gpt-"):
        response = OPENAI_CLIENT.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=1000,
        )
        # The SDK types message.content as optional; treat None as empty text.
        return (response.choices[0].message.content or "").strip()
    if model_name == "gemini-pro":
        model_response = GEMINI_MODEL.generate_content(f"{system_prompt}, {prompt}")
        return model_response.candidates[0].content.parts[0].text
    raise gr.Error(f"Unsupported model: {model_name}")


def generate_new_article(lesson_words, original_article, original_word_count, base_chars, model_name):
    """Generate a new first-grade Chinese article and validate it, retrying on failure.

    Up to three attempts are made. After each invalid attempt, the rejected
    article and the validation problems are appended to the prompt so the next
    attempt can correct them.

    Args:
        lesson_words: Characters that must all appear in the article.
        original_article: Reference article whose spirit/style/rhythm is imitated.
        original_word_count: Character count of the reference article (target length).
        base_chars: Additional characters the article may use.
        model_name: OpenAI "gpt-*" model name or "gemini-pro".

    Returns:
        (generated_article, validate_article_result) from the last attempt,
        whether or not it passed validation.

    Raises:
        gr.Error: if a required input is missing or the model is unsupported.
    """
    # check lesson_words, original_article, original_word_count exist
    if not lesson_words or not original_article or not original_word_count:
        raise gr.Error("lesson_words, original_article, original_word_count are required. Please upload the lesson csv file.")
    max_attempts = 3
    generated_article = ""
    system_prompt = "You are a creative writer specialized in Chinese Children book. You will help me write Chinese Articles."
    # Build the base prompt ONCE, outside the retry loop. Previously it was
    # rebuilt at the top of every attempt, silently discarding the correction
    # feedback appended after a failed attempt.
    prompt = f"""
Please write a new and original Chinese article tailored for first-grade students. Here's a summary of the key points that you should follow:
Use Traditional Chinese (ZH-TW) Characters: The article should be written in Traditional Chinese(ZH-TW), not Simplified Chinese.
Adherence to the Original Article: The new creation should closely follow the spirit, style, and rhythmic pattern of the provided original article. The number of words, excluding punctuation marks, should be similar to that of the original, approximately {original_word_count} words.
Incorporate "New Words": Every word listed under "new words" must be used in the article. These words are: {lesson_words}.
Utilize the "Word Library": Additional words required for the article can be selected from the provided "word library," which includes: {base_chars}.
Restriction on Vocabulary: Do not use any words outside the "new words" or the "word library".
Originality: The new article must be unique and original, not a copy of the original work.
"Original Article" for Reference: The example provided is {original_article}。This article serves as a model for the spirit, style, and rhythmic pattern to be emulated.
"""
    for attempt in range(1, max_attempts + 1):
        print("================Attempt=====================")
        print(f"Attempt {attempt} to generate new article")
        print("===========================================")
        generated_text = _generate_text(model_name, system_prompt, prompt)
        generated_article = extract_article_from_content(generated_text)
        validate_article_result = validate_article(generated_article, lesson_words, base_chars, original_word_count)
        not_every_new_word_is_used = validate_article_result['not_every_new_word_is_used']
        word_out_of_range = validate_article_result['word_out_of_range']
        word_count_error = validate_article_result['word_count_error']
        count_of_words_in_new_article = validate_article_result['count_of_words_in_new_article']
        print("====validate_article====")
        print(f"not_every_new_word_is_used: {not_every_new_word_is_used}")
        print(f"word_out_of_range: {word_out_of_range}")
        print(f"word_count_error: {word_count_error}")
        print("=========================")
        if not not_every_new_word_is_used and not word_out_of_range and not word_count_error:
            print("Generated article is valid")
            break

        print("Generated article is invalid")
        error_messages = []
        if not_every_new_word_is_used:
            error_messages.append("Not every new word is used in the article.")
        if word_out_of_range:
            error_messages.append("The article contains words that are not in the new words or word library.")
        if word_count_error:
            error_messages.append(f"The word count of the new article deviates more than 10% from the original ({original_word_count}).")
        error_messages_str = "\n".join(error_messages) + "\n"
        # Feed the rejected article and its problems back for the next attempt.
        prompt += f"""
The new article is {generated_article}.
word_count is {count_of_words_in_new_article}.
But the generated article is invalid. The following issues were found:
{error_messages_str}
please follow the summary of the key points and fix the errors to generate a new article.
"""
        print(f"Prompt for next attempt: {prompt}")
    return generated_article, validate_article_result
def load_lesson_csv(file):
    """Read the first row of an uploaded lesson CSV into the form fields.

    Columns read: ``lesson_words`` and ``lesson_article`` (``original_article``
    is accepted as an alias); an optional ``base_chars`` column fills the base
    characters field, otherwise it is returned empty.

    Args:
        file: The upload as Gradio supplies it — a path string, a tempfile-like
            object whose ``.name`` is the path on disk, or a buffer pandas can
            read. ``None`` (upload cleared) is allowed.

    Returns:
        (lesson_words, original_article, original_word_count, base_chars).
        original_word_count counts the article's characters, excluding
        whitespace and (ASCII or full-width) punctuation. Returns
        ("", "", 0, "") for no file, an empty file, or a CSV with no data rows.

    Raises:
        KeyError: if the lesson_words or article column is missing.
    """
    if file is None:
        return "", "", 0, ""
    # Use the on-disk path when given a tempfile wrapper; pass paths/buffers through.
    source = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
    try:
        df = pd.read_csv(source, encoding="utf-8")
    except pd.errors.EmptyDataError:
        return "", "", 0, ""
    if df.empty:
        return "", "", 0, ""
    first_row = df.iloc[0]

    def text_at(column):
        # Missing column -> KeyError; empty cell (NaN) -> "".
        value = first_row[column]
        return "" if pd.isna(value) else str(value)

    lesson_words = text_at("lesson_words")
    article_column = next(
        (c for c in ("lesson_article", "original_article") if c in df.columns),
        "lesson_article",
    )
    original_article = text_at(article_column)
    original_word_count = sum(
        1 for ch in original_article
        if not ch.isspace() and not unicodedata.category(ch).startswith("P")
    )
    base_chars = text_at("base_chars") if "base_chars" in df.columns else ""
    return lesson_words, original_article, original_word_count, base_chars
def load_base_chars_csv(file):
    """Return the ``words`` value from the first row of an uploaded base-characters CSV.

    Args:
        file: The upload as Gradio supplies it — a path string, a tempfile-like
            object whose ``.name`` is the path on disk, or a buffer pandas can
            read. ``None`` (upload cleared) is allowed.

    Returns:
        The base characters as a string; "" for no file, an empty file, a CSV
        with no data rows, or an empty cell.

    Raises:
        KeyError: if the CSV has rows but no ``words`` column.
    """
    if file is None:
        return ""
    # Use the on-disk path when given a tempfile wrapper; pass paths/buffers through.
    source = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
    try:
        df = pd.read_csv(source, encoding="utf-8")
    except pd.errors.EmptyDataError:
        return ""
    if df.empty:
        return ""
    base_chars = df.iloc[0]["words"]
    return "" if pd.isna(base_chars) else str(base_chars)
# Gradio UI: upload lesson/base-character CSVs, then compare up to four models
# side by side. Each row's dropdown can select any model, so row labels refer
# to the row ("Model N") rather than hard-coding a model name that may be stale.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Original Lesson CSV File")
            lesson_csv_file_input = gr.File(label="Upload CSV file (Columns: lesson_words, lesson_article)")
        with gr.Column():
            gr.Markdown("### Base Characters CSV File")
            base_chars_csv_file_input = gr.File(label="Upload Base Characters File")
    with gr.Row():
        lesson_words_input = gr.Textbox(label="Lesson Words")
        original_article_input = gr.Textbox(label="Original Article")
        original_word_count_input = gr.Number(label="Original Word Count")
        base_chars_input = gr.Textbox(label="Base Characters")

    # Each entry is both a selectable model and the default for one comparison row.
    model_list = ["gpt-4-0125-preview", "gpt-3.5-turbo", "gpt-4", "gemini-pro"]
    shared_inputs = [lesson_words_input, original_article_input, original_word_count_input, base_chars_input]
    for row_number, default_model in enumerate(model_list, start=1):
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(label=f"Model {row_number}", choices=model_list, value=default_model)
                generate_button = gr.Button(f"Generate Article - Model {row_number}")
            with gr.Column():
                output_text = gr.Textbox(label=f"Generated Article - Model {row_number}")
            with gr.Column():
                validation_json = gr.JSON()
        # Components are passed directly (no closures), so binding inside the loop is safe.
        generate_button.click(
            generate_new_article,
            inputs=[*shared_inputs, model_dropdown],
            outputs=[output_text, validation_json],
        )

    lesson_csv_file_input.change(
        load_lesson_csv,
        inputs=[lesson_csv_file_input],
        outputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input],
    )
    base_chars_csv_file_input.change(
        load_base_chars_csv,
        inputs=[base_chars_csv_file_input],
        outputs=[base_chars_input],
    )
demo.launch()