# Hugging Face Spaces page header captured with the source ("Spaces:",
# status "Sleeping") -- scraping residue, not part of the program.
import gradio as gr
import os
import pandas as pd
from openai import OpenAI
import json
from google.cloud import aiplatform
from vertexai.preview.generative_models import GenerativeModel
from google.oauth2.service_account import Credentials

# --- Credential / client setup ----------------------------------------------
# Local runs read secrets from local.json; deployed runs read env vars.
# NOTE: os.getenv returns a *string*, so ANY non-empty value (even "0" or
# "false") enables local mode -- set the variable only when running locally.
IS_ENV_LOCAL = os.getenv("IS_ENV_LOCAL", False)
if IS_ENV_LOCAL:
    # context manager closes the secrets file deterministically
    # (the original json.load(open(...)) leaked the handle)
    with open("local.json", encoding="utf-8") as secrets_file:
        local_json = json.load(secrets_file)
    openai_api_key = local_json["OPENAI_API_KEY"]
    GOOGLE_SERVICE_ACCOUNT_INFO = local_json["GBQ_TOKEN"]
    # local.json already stores the token as a dict
    google_service_account_info_dict = GOOGLE_SERVICE_ACCOUNT_INFO
else:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    GOOGLE_SERVICE_ACCOUNT_INFO = os.getenv("GBQ_TOKEN")
    # env var holds the service-account token as a JSON string
    google_service_account_info_dict = json.loads(GOOGLE_SERVICE_ACCOUNT_INFO)

# OpenAI client shared by all requests
OPENAI_CLIENT = OpenAI(api_key=openai_api_key)

# Google / Vertex AI setup (typo fixed: GOOGPE_SCOPES -> GOOGLE_SCOPES; the
# name is referenced only in this block, so the rename is safe)
GOOGLE_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
google_creds = Credentials.from_service_account_info(
    google_service_account_info_dict, scopes=GOOGLE_SCOPES
)
aiplatform.init(
    project="junyiacademy",
    service_account=google_service_account_info_dict,
    credentials=google_creds,
)
# Shared Gemini model instance used by generate_new_article
GEMINI_MODEL = GenerativeModel("gemini-pro")
def extract_article_from_content(article_text):
    """Extract the article body from a raw model response.

    Searches for the first matching start marker ("新文章:", "New Article:",
    or "Here it is:") and returns the text after it, up to an optional
    "\\nThank you" trailer. Leading newlines after the marker are skipped
    and trailing whitespace is stripped. If no marker is found, the raw
    response is returned unchanged.

    Args:
        article_text: full text returned by the model.

    Returns:
        The extracted article, or article_text itself when no marker matches.
    """
    start_markers = ["新文章:", "New Article:", "Here it is:"]
    end_marker = "\nThank you"
    for start_marker in start_markers:
        start_index = article_text.find(start_marker)
        if start_index == -1:
            continue
        start_index += len(start_marker)
        # Skip newlines after the marker. The bounds check fixes an
        # IndexError in the original when the response ended exactly at
        # the marker (start_index == len(article_text)).
        while start_index < len(article_text) and article_text[start_index] == "\n":
            start_index += 1
        end_index = article_text.find(end_marker, start_index)
        if end_index != -1:
            return article_text[start_index:end_index].rstrip()
        return article_text[start_index:].rstrip()
    return article_text
def validate_article(generated_article, lesson_words, base_chars, original_word_count):
    """Validate a generated article against the lesson constraints.

    Strips punctuation/newlines, then checks that every lesson character is
    used, that only characters from lesson_words + base_chars appear, and
    that the length stays within +/-10% of original_word_count.

    Args:
        generated_article: candidate article text.
        lesson_words: string of mandatory "new word" characters.
        base_chars: string of additional allowed "word library" characters.
        original_word_count: target character count (punctuation excluded).

    Returns:
        Dict with boolean flags (not_every_new_word_is_used,
        word_out_of_range, word_count_error) and diagnostics
        (lesson_words_not_in_new_article, words_not_in_both,
        additional_words, count_of_words_in_new_article).
    """
    # Characters ignored for all checks (punctuation, quotes, newlines).
    punctuation = "、,。!?;:「」『』()《》【】'\n'"
    clean_article = "".join(ch for ch in generated_article if ch not in punctuation)

    lesson_char_list = list(lesson_words)
    base_char_list = list(base_chars)

    # Lesson characters that never appear in the cleaned article.
    missing_lesson_chars = [ch for ch in lesson_char_list if ch not in clean_article]
    not_every_new_word_is_used = bool(missing_lesson_chars)

    # Any character outside lesson words + word library is out of range.
    allowed_chars = set(lesson_words + base_chars)
    word_out_of_range = not set(clean_article) <= allowed_chars

    # Length must stay within +/-10% of the original article's count.
    article_length = len(clean_article)
    word_count_error = not (
        0.9 * original_word_count <= article_length <= 1.1 * original_word_count
    )

    # Missing lesson characters that the word library cannot cover either.
    words_not_in_both = [ch for ch in missing_lesson_chars if ch not in base_char_list]

    # Characters used that belong to neither allowed source.
    additional_words = {
        ch
        for ch in clean_article
        if ch not in lesson_char_list and ch not in base_char_list
    }

    return {
        "not_every_new_word_is_used": not_every_new_word_is_used,
        "word_out_of_range": word_out_of_range,
        "word_count_error": word_count_error,
        "lesson_words_not_in_new_article": missing_lesson_chars,
        "words_not_in_both": words_not_in_both,
        "additional_words": additional_words,
        "count_of_words_in_new_article": article_length,
    }
def generate_new_article(lesson_words, original_article, original_word_count, base_chars, model_name):
    """Generate a new first-grade Chinese article with the selected model.

    Makes up to 3 attempts; after each invalid attempt the validation errors
    are appended to the prompt so the next attempt can self-correct.

    Args:
        lesson_words: mandatory "new word" characters (string).
        original_article: reference article to emulate.
        original_word_count: target word count (result must be within 10%).
        base_chars: allowed "word library" characters.
        model_name: an OpenAI chat model name ("gpt-...") or "gemini-pro".

    Returns:
        Tuple (generated_article, validate_article_result) from the last
        attempt (valid or not).

    Raises:
        gr.Error: if a required input is missing or model_name is unknown.
    """
    # check lesson_words, original_article, original_word_count exist
    if not lesson_words or not original_article or not original_word_count:
        raise gr.Error("lesson_words, original_article, original_word_count are required. Please upload the lesson csv file.")
    attempt = 0
    max_attempts = 3
    generated_article = ""
    validate_article_result = {}
    system_prompt = "You are a creative writer specialized in Chinese Children book. You will help me write Chinese Articles."
    # Built once, OUTSIDE the retry loop: the original rebuilt `prompt` at the
    # top of every iteration, which silently discarded the error feedback it
    # appended below -- the retry mechanism never reached the model.
    prompt = f"""
Please write a new and original Chinese article tailored for first-grade students. Here's a summary of the key points that you should follow:
Use Traditional Chinese (ZH-TW) Characters: The article should be written in Traditional Chinese(ZH-TW), not Simplified Chinese.
Adherence to the Original Article: The new creation should closely follow the spirit, style, and rhythmic pattern of the provided original article. The number of words, excluding punctuation marks, should be similar to that of the original, approximately {original_word_count} words.
Incorporate "New Words": Every word listed under "new words" must be used in the article. These words are: {lesson_words}.
Utilize the "Word Library": Additional words required for the article can be selected from the provided "word library," which includes: {base_chars}.
Restriction on Vocabulary: Do not use any words outside the "new words" or the "word library".
Originality: The new article must be unique and original, not a copy of the original work.
"Original Article" for Reference: The example provided is {original_article}。This article serves as a model for the spirit, style, and rhythmic pattern to be emulated.
"""
    while attempt < max_attempts:
        attempt += 1
        print("================Attempt=====================")
        print(f"Attempt {attempt} to generate new article")
        print("===========================================")
        # Dispatch by model family. The UI offers "gpt-4-0125-preview", which
        # the original membership list ["gpt-4-turbo", "gpt-4",
        # "gpt-3.5-turbo"] missed, leaving generated_text unbound (NameError).
        # A "gpt" prefix match covers every OpenAI model in the dropdown.
        if model_name.startswith("gpt"):
            response = OPENAI_CLIENT.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1000
            )
            generated_text = response.choices[0].message.content.strip()
        elif model_name == "gemini-pro":
            # Gemini has no separate system role here; prepend the system prompt.
            model_response = GEMINI_MODEL.generate_content(
                f"{system_prompt}, {prompt}"
            )
            generated_text = model_response.candidates[0].content.parts[0].text
        else:
            raise gr.Error(f"Unknown model: {model_name}")
        generated_article = extract_article_from_content(generated_text)
        validate_article_result = validate_article(generated_article, lesson_words, base_chars, original_word_count)
        not_every_new_word_is_used = validate_article_result['not_every_new_word_is_used']
        word_out_of_range = validate_article_result['word_out_of_range']
        word_count_error = validate_article_result['word_count_error']
        count_of_words_in_new_article = validate_article_result['count_of_words_in_new_article']
        print("====validate_article====")
        print(f"not_every_new_word_is_used: {not_every_new_word_is_used}")
        print(f"word_out_of_range: {word_out_of_range}")
        print(f"word_count_error: {word_count_error}")
        print("=========================")
        if not not_every_new_word_is_used and not word_out_of_range and not word_count_error:
            print("Generated article is valid")
            break
        print("Generated article is invalid")
        error_messages = []
        if not_every_new_word_is_used:
            error_messages.append("Not every new word is used in the article.")
        if word_out_of_range:
            error_messages.append("The article contains words that are not in the new words or word library.")
        if word_count_error:
            error_messages.append(f"The word count of the new article deviates more than 10% from the original ({original_word_count}).")
        # Append the error messages to the prompt for the next attempt
        error_messages_str = "\n".join(error_messages) + "\n"
        prompt += f"""
The new article is {generated_article}.
word_count is {count_of_words_in_new_article}.
But the generated article is invalid. The following issues were found:
{error_messages_str}
please follow the summary of the key points and fix the errors to generate a new article.
"""
        print(f"Prompt for next attempt: {prompt}")
    return generated_article, validate_article_result
def load_lesson_csv(file):
    """Load lesson words and the original article from an uploaded CSV.

    Accepts either a filepath string or a gradio File object (which exposes
    the path via .name). Only the first data row is used; it must contain
    'lesson_words' and 'lesson_article' columns.

    Args:
        file: filepath string or gradio upload object.

    Returns:
        Tuple (lesson_words, original_article, original_word_count,
        base_chars). base_chars is always '' here (see inline note);
        ("", "", 0, "") when the CSV has no data rows.
    """
    try:
        df = pd.read_csv(file, encoding='utf-8')
    except Exception:
        # bare `except:` replaced: it also swallowed SystemExit and
        # KeyboardInterrupt. Fallback covers gradio's tempfile wrapper.
        df = pd.read_csv(file.name, encoding='utf-8')
    if df.empty:
        return "", "", 0, ""
    first_row = df.iloc[0]
    lesson_words = first_row['lesson_words']
    original_article = first_row['lesson_article']
    # Count characters excluding punctuation, quotes and line breaks.
    punctuation = "、,。!?;:「」『』()《》【】'\n''\r'"
    clean_original_article = [char for char in original_article if char not in punctuation]
    original_word_count = len(clean_original_article)
    # This should be defined or extracted from some column or external source.
    # NOTE(review): returning '' here overwrites any previously loaded
    # base-characters value in the UI -- confirm that is intended.
    base_chars = ''
    return lesson_words, original_article, original_word_count, base_chars
def load_base_chars_csv(file):
    """Load the base character library from an uploaded CSV.

    Accepts a filepath string or a gradio File object (path via .name).
    Reads the 'words' column of the first data row.

    Args:
        file: filepath string or gradio upload object.

    Returns:
        The base characters string, or "" when the CSV has no data rows.
    """
    try:
        df = pd.read_csv(file, encoding='utf-8')
    except Exception:
        # bare `except:` replaced (it also swallowed SystemExit and
        # KeyboardInterrupt); fallback covers gradio's tempfile wrapper.
        df = pd.read_csv(file.name, encoding='utf-8')
    if df.empty:
        return ""
    return df.iloc[0]['words']
# --- Gradio UI layout and event wiring ---------------------------------------
with gr.Blocks() as demo:
    # Upload widgets: the lesson CSV fills the four textboxes below; the
    # base-characters CSV fills only the "Base Characters" box.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Original Lesson CSV File")
            lesson_csv_file_input = gr.File(label="Upload CSV file (Columns: lesson_words, original_article, original_word_count, base_chars)")
        with gr.Column():
            gr.Markdown("### Base Characters CSV File")
            base_chars_csv_file_input = gr.File(label="Upload Base Characters File")
    # Editable inputs passed to generate_new_article on every button click.
    with gr.Row():
        lesson_words_input = gr.Textbox(label="Lesson Words")
        original_article_input = gr.Textbox(label="Original Article")
        original_word_count_input = gr.Number(label="Original Word Count")
        base_chars_input = gr.Textbox(label="Base Characters")
    # Four model slots for side-by-side comparison: each has a model
    # dropdown, a generate button, the article output, and a JSON
    # validation report.
    # NOTE(review): button labels hard-code the default model name and do
    # not follow the dropdown selection.
    with gr.Row():
        model_list = ["gpt-4-0125-preview", "gpt-3.5-turbo", "gpt-4", "gemini-pro"]
        with gr.Column():
            model_1 = gr.Dropdown(label="Model 1", choices=model_list, value="gpt-4-0125-preview")
            generate_button1 = gr.Button("Generate Article - gpt-4-0125-preview")
        with gr.Column():
            output_text1 = gr.Textbox(label="Generated Article - gpt-4-0125-preview")
        with gr.Column():
            validate_article_result_1 = gr.JSON()
    with gr.Row():
        with gr.Column():
            model_2 = gr.Dropdown(label="Model 2", choices=model_list, value="gpt-3.5-turbo")
            generate_button2 = gr.Button("Generate Article - gpt-3.5-turbo")
        with gr.Column():
            output_text2 = gr.Textbox(label="Generated Article - gpt-3.5-turbo")
        with gr.Column():
            validate_article_result_2 = gr.JSON()
    with gr.Row():
        with gr.Column():
            model_3 = gr.Dropdown(label="Model 3", choices=model_list, value="gpt-4")
            generate_button3 = gr.Button("Generate Article - gpt-4")
        with gr.Column():
            output_text3 = gr.Textbox(label="Generated Article - gpt-4")
        with gr.Column():
            # validate_article_result_3 Json format
            validate_article_result_3 = gr.JSON()
    with gr.Row():
        with gr.Column():
            model_4 = gr.Dropdown(label="Model 4", choices=model_list, value="gemini-pro")
            generate_button4 = gr.Button("Generate Article - gemini-pro")
        with gr.Column():
            output_text4 = gr.Textbox(label="Generated Article - gemini-pro")
        with gr.Column():
            validate_article_result_4 = gr.JSON()

    # Wire each generate button to generate_new_article with its own
    # model dropdown; outputs go to that slot's textbox and JSON report.
    generate_button1.click(
        generate_new_article,
        inputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input, model_1],
        outputs=[output_text1, validate_article_result_1]
    )
    generate_button2.click(
        generate_new_article,
        inputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input, model_2],
        outputs=[output_text2, validate_article_result_2]
    )
    generate_button3.click(
        generate_new_article,
        inputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input, model_3],
        outputs=[output_text3, validate_article_result_3]
    )
    generate_button4.click(
        generate_new_article,
        inputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input, model_4],
        outputs=[output_text4, validate_article_result_4]
    )
    # Add click events for other models
    # Uploading a lesson CSV populates all four input boxes.
    # NOTE(review): load_lesson_csv always returns '' for base_chars, so a
    # lesson upload clears a previously loaded base-characters value.
    lesson_csv_file_input.change(
        load_lesson_csv,
        inputs=[lesson_csv_file_input],
        outputs=[lesson_words_input, original_article_input, original_word_count_input, base_chars_input]
    )
    base_chars_csv_file_input.change(
        load_base_chars_csv,
        inputs=[base_chars_csv_file_input],
        outputs=[base_chars_input]
    )

demo.launch()