# Hugging Face Space — status header pasted from the web UI ("Runtime error").
# Kept as a comment so the file remains valid Python.
import multiprocessing
import os
import random
import signal
import threading

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# ---------------------------------------------------------------------------
# Dataset preparation: tiny English -> Japanese word-pair dictionary.
# ---------------------------------------------------------------------------
data = {
    "hello": "こんにちは",
    "world": "世界",
    "good": "良い",
    "morning": "朝",
    "evening": "晩",
    "night": "夜",
    "day": "日",
    "thank": "ありがとう",
    "you": "あなた",
    # Add more word pairs here.
}

input_texts = list(data.keys())
target_texts = list(data.values())

# Vocabulary lookup tables (word -> one-hot index).
# BUG FIX: the original built these from `set(...)`, whose iteration order is
# nondeterministic under Python hash randomization, so token indices changed
# between runs and any model saved in a previous run became unusable.  Dict
# keys and values here are already unique, so enumerating the lists directly
# gives a stable, insertion-ordered mapping.
input_token_index = {word: i for i, word in enumerate(input_texts)}
target_token_index = {word: i for i, word in enumerate(target_texts)}

# Vocabulary sizes.
num_encoder_tokens = len(input_token_index)
num_decoder_tokens = len(target_token_index)

# One-hot encode every (English, Japanese) pair.
encoder_input_data = np.zeros((len(input_texts), num_encoder_tokens), dtype='float32')
decoder_input_data = np.zeros((len(target_texts), num_decoder_tokens), dtype='float32')
decoder_target_data = np.zeros((len(target_texts), num_decoder_tokens), dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    encoder_input_data[i, input_token_index[input_text]] = 1.
    decoder_input_data[i, target_token_index[target_text]] = 1.
    decoder_target_data[i, target_token_index[target_text]] = 1.
# ---------------------------------------------------------------------------
# Model construction: two one-hot inputs -> softmax over the target vocab.
# Despite the encoder/decoder naming this is a one-shot feed-forward
# classifier, not a recurrent seq2seq model.
#
# BUG FIX: in the original graph the encoder branch was built but never
# connected to the output — the model's prediction depended only on the
# decoder input, i.e. not on the English word at all.  The two 256-unit
# branches are now merged before the classification layer.
# ---------------------------------------------------------------------------
encoder_inputs = Input(shape=(num_encoder_tokens,))
encoder_hidden = Dense(256, activation='relu')(encoder_inputs)

decoder_inputs = Input(shape=(num_decoder_tokens,))
decoder_hidden = Dense(256, activation='relu')(decoder_inputs)

merged = Concatenate()([encoder_hidden, decoder_hidden])
decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(merged)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer=Adam(), loss='categorical_crossentropy')
# Reinforcement-learning style wrapper around the translation model.
class ReinforcementTranslator:
    """Reward-driven trainer/decoder for the word-level translation model.

    Parameters
    ----------
    model : Keras model taking ``[encoder_onehot, decoder_onehot]`` and
        returning a softmax distribution over the target vocabulary.
    input_token_index : dict
        English word -> one-hot index.
    target_token_index : dict
        Japanese word -> one-hot index.
    """

    def __init__(self, model, input_token_index, target_token_index):
        self.model = model
        self.input_token_index = input_token_index
        self.target_token_index = target_token_index
        # Inverse table: argmax index -> target word.
        self.reverse_target_token_index = {
            i: word for word, i in target_token_index.items()
        }
        # Total reward recorded per training epoch.
        self.rewards = []

    def translate(self, input_seq):
        """Decode a one-hot encoded input word and return the target word.

        BUG FIX: the original ran an autoregressive loop that (a) called
        ``.predict`` on individual layers, which Keras layers do not
        support, and (b) indexed the vocabulary with '<start>'/'<end>'
        tokens that were never added to it, raising KeyError on the very
        first call.  The model is a one-shot classifier, so a single
        forward pass with a zero decoder input is the correct decode.
        """
        decoder_probe = np.zeros((1, len(self.target_token_index)), dtype='float32')
        output_tokens = self.model.predict([input_seq, decoder_probe], verbose=0)
        sampled_token_index = int(np.argmax(output_tokens[0]))
        return self.reverse_target_token_index[sampled_token_index]

    def train(self, input_texts, target_texts, epochs=100):
        """Run ``epochs`` passes over the word pairs, logging total reward.

        The reward only drives logging/monitoring; the weight update is a
        plain supervised fit on the module-level one-hot arrays.

        BUG FIX: the original called ``fit([input_seq, decoder_input_data],
        decoder_target_data, ...)`` mixing a batch of 1 with a batch of 9,
        which Keras rejects.  Fit now uses the consistent full arrays.
        """
        for epoch in range(epochs):
            total_reward = 0
            for input_text, target_text in zip(input_texts, target_texts):
                input_seq = np.zeros((1, len(self.input_token_index)), dtype='float32')
                input_seq[0, self.input_token_index[input_text]] = 1.
                predicted_translation = self.translate(input_seq)
                total_reward += self.calculate_reward(predicted_translation, target_text)
                # Supervised update on the whole dataset (module globals).
                self.model.fit([encoder_input_data, decoder_input_data],
                               decoder_target_data,
                               epochs=1, batch_size=1, verbose=0)
            self.rewards.append(total_reward)
            print(f'Epoch {epoch + 1}/{epochs}, Total Reward: {total_reward}')

    def calculate_reward(self, predicted, target):
        """Return +1 for an exact (whitespace-insensitive) match, else -1."""
        return 1 if predicted.strip() == target.strip() else -1
| # 並列トレーニング | |
| def train_model(model, input_texts, target_texts, epochs, model_id, rewards): | |
| translator = ReinforcementTranslator(model, input_token_index, target_token_index) | |
| translator.train(input_texts, target_texts, epochs) | |
| rewards[model_id] = translator.rewards | |
| return model_id, translator.model | |
if __name__ == '__main__':
    # -----------------------------------------------------------------------
    # Train an ensemble of 9 models.
    #
    # BUG FIX: the original sent live Keras models through
    # multiprocessing.Pool.starmap.  Keras models are not picklable, which
    # crashes at startup (the Space's "Runtime error"); the plain `rewards`
    # dict also cannot be shared across processes, so every reward history
    # was silently lost.  Training now runs sequentially in-process, and
    # each clone is compiled first (clone_model does not copy the
    # optimizer/loss configuration needed by ``fit``).
    # -----------------------------------------------------------------------
    models = [tf.keras.models.clone_model(model) for _ in range(9)]
    rewards = {i: [] for i in range(9)}

    results = []
    for i, candidate in enumerate(models):
        candidate.compile(optimizer=Adam(), loss='categorical_crossentropy')
        results.append(train_model(candidate, input_texts, target_texts, 100, i, rewards))

    # Persist the trained models.
    for model_id, trained_model in results:
        trained_model.save(f'model_{model_id}.h5')

    # Load the serving model once at startup instead of on every request.
    best_model = tf.keras.models.load_model('model_0.h5')
    best_translator = ReinforcementTranslator(best_model, input_token_index, target_token_index)

    def process_file(input_text):
        """Translate a single English word to Japanese (empty-safe)."""
        if input_text not in input_token_index:
            # Original raised KeyError on any word outside the vocabulary.
            return "Unknown word"
        input_seq = np.zeros((1, num_encoder_tokens), dtype='float32')
        input_seq[0, input_token_index[input_text]] = 1.
        return best_translator.translate(input_seq).strip()

    def get_rewards():
        """Per-model reward history for the JSON viewer."""
        return {f'Model {i}': rewards[i] for i in range(9)}

    def stop_gradio():
        """Ask the server process to shut down."""
        os.kill(os.getpid(), signal.SIGINT)
        return "Server stopping..."

    # BUG FIX: the original created gr.Button/gr.JSON outside any Blocks
    # context (they were never rendered) and passed outputs="text" to
    # click(); Gradio components must live inside a gr.Blocks layout and
    # click() outputs must be components.
    with gr.Blocks(title="英語単語の翻訳") as iface:
        gr.Markdown("このインターフェースでは、英語の単語を入力し、その日本語翻訳を生成します。")
        word_in = gr.Textbox(label="English word")
        word_out = gr.Textbox(label="Japanese translation")
        word_in.submit(fn=process_file, inputs=word_in, outputs=word_out)

        rewards_button = gr.Button("Get Rewards")
        rewards_output = gr.JSON()
        rewards_button.click(fn=get_rewards, inputs=[], outputs=rewards_output)

        status = gr.Textbox(label="Status")
        stop_button = gr.Button("Stop Server")
        stop_button.click(fn=stop_gradio, inputs=[], outputs=status)

    # Launch the web interface.
    iface.launch()