Sakalti committed on
Commit
ca0df76
·
verified ·
1 Parent(s): 89d87d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import random
5
+ import multiprocessing
6
+ import os
7
+ import signal
8
+ import threading
9
+ from tensorflow.keras.models import Model
10
+ from tensorflow.keras.layers import Input, Dense
11
+ from tensorflow.keras.optimizers import Adam
12
+
13
# Toy English -> Japanese word-translation dataset.
data = {
    "hello": "こんにちは",
    "world": "世界",
    "good": "良い",
    "morning": "朝",
    "evening": "晩",
    "night": "夜",
    "day": "日",
    "thank": "ありがとう",
    "you": "あなた",
    # Add more pairs here.
}

input_texts = list(data.keys())
target_texts = list(data.values())

# Build the vocabularies: one integer index per token.
# Enumerate the lists directly rather than wrapping them in set(...):
# dict keys (and this dict's values) are already unique, and iterating a
# set of strings is nondeterministic across interpreter runs (hash
# randomization), which would shuffle every token index on each restart
# and invalidate any model saved by a previous run.
input_token_index = {word: i for i, word in enumerate(input_texts)}
target_token_index = {word: i for i, word in enumerate(target_texts)}

# Vocabulary sizes.
num_encoder_tokens = len(input_token_index)
num_decoder_tokens = len(target_token_index)

# One-hot encode every (input, target) pair: row i is the one-hot vector
# for pair i. decoder_input_data and decoder_target_data are identical
# by construction here (no time shift — single-token "sequences").
encoder_input_data = np.zeros((len(input_texts), num_encoder_tokens), dtype='float32')
decoder_input_data = np.zeros((len(target_texts), num_decoder_tokens), dtype='float32')
decoder_target_data = np.zeros((len(target_texts), num_decoder_tokens), dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    encoder_input_data[i, input_token_index[input_text]] = 1.
    decoder_input_data[i, target_token_index[target_text]] = 1.
    decoder_target_data[i, target_token_index[target_text]] = 1.
47
+
48
# Model construction: two dense branches over one-hot word vectors.
# NOTE(review): the encoder branch (encoder_inputs -> encoder_outputs) is
# listed as a model input but never feeds into decoder_outputs, so the
# model's prediction depends on the decoder input alone. Confirm whether
# the branches were meant to be merged (e.g. concatenated) before the
# softmax layer.
encoder_inputs = Input(shape=(num_encoder_tokens,))
encoder_dense = Dense(256, activation='relu')
encoder_outputs = encoder_dense(encoder_inputs)

decoder_inputs = Input(shape=(num_decoder_tokens,))
decoder_dense = Dense(256, activation='relu')
decoder_outputs = decoder_dense(decoder_inputs)

# Final softmax over the target vocabulary.
decoder_dense_output = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense_output(decoder_outputs)

# Two-input functional model; targets are one-hot vectors, hence
# categorical crossentropy.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer=Adam(), loss='categorical_crossentropy')
62
+
63
# Reinforcement-learning-style wrapper around the translation model.
class ReinforcementTranslator:
    """Greedy decoding plus a reward-logging training loop for `model`.

    NOTE(review): `translate` and `train` read the module-level globals
    `num_decoder_tokens`, `num_encoder_tokens`, `target_token_index`,
    `input_token_index`, `decoder_input_data` and `decoder_target_data`
    instead of instance state — confirm this coupling is intended.
    """

    def __init__(self, model, input_token_index, target_token_index):
        self.model = model
        self.input_token_index = input_token_index
        self.target_token_index = target_token_index
        # Inverse mapping: argmax index -> target-language word.
        self.reverse_target_token_index = {i: word for word, i in target_token_index.items()}
        # One total-reward entry appended per training epoch.
        self.rewards = []

    def translate(self, input_seq):
        """Greedily decode a translation for a one-hot encoded `input_seq`.

        NOTE(review): `model.layers[1]` / `model.layers[3]` index into the
        functional model by position, which is fragile; layers[3] is a
        single Dense layer and cannot accept a two-element input list, so
        this call fails as written. Verify the intended sub-models.
        """
        # Encoder output.
        encoder_output = self.model.layers[1].predict(input_seq)

        # Decoder seed input.
        # NOTE(review): '<start>' (and '<end>' below) are never added to
        # target_token_index as built above, so this raises KeyError.
        target_seq = np.zeros((1, num_decoder_tokens))
        target_seq[0, target_token_index['<start>']] = 1.

        stop_condition = False
        decoded_sentence = ''
        while not stop_condition:
            output_tokens = self.model.layers[3].predict([encoder_output, target_seq])
            sampled_token_index = np.argmax(output_tokens[0])
            sampled_word = self.reverse_target_token_index[sampled_token_index]
            decoded_sentence += ' ' + sampled_word

            # Stop on the end marker or once the output exceeds 50 chars.
            if (sampled_word == '<end>' or len(decoded_sentence) > 50):
                stop_condition = True

            # Feed the sampled token back in as the next decoder input.
            target_seq = np.zeros((1, num_decoder_tokens))
            target_seq[0, sampled_token_index] = 1.

        return decoded_sentence.strip()

    def train(self, input_texts, target_texts, epochs=100):
        """Run `epochs` passes over the data, logging one reward per epoch.

        NOTE(review): the computed reward is recorded but never influences
        the gradient step — model.fit below always trains on the full
        one-hot dataset, so this is supervised training with reward
        logging rather than reinforcement learning. Confirm intent.
        """
        for epoch in range(epochs):
            total_reward = 0
            for input_text, target_text in zip(input_texts, target_texts):
                # One-hot encode the source word.
                input_seq = np.zeros((1, num_encoder_tokens))
                input_seq[0, input_token_index[input_text]] = 1.

                predicted_translation = self.translate(input_seq)
                reward = self.calculate_reward(predicted_translation, target_text)
                total_reward += reward

                # Model update (supervised step on the whole dataset's
                # decoder targets; see NOTE in the docstring).
                self.model.fit([input_seq, decoder_input_data], decoder_target_data, epochs=1, batch_size=1, verbose=0)

            self.rewards.append(total_reward)
            print(f'Epoch {epoch + 1}/{epochs}, Total Reward: {total_reward}')

    def calculate_reward(self, predicted, target):
        # +1 for an exact (whitespace-insensitive) match, -1 otherwise.
        if predicted.strip() == target.strip():
            return 1
        else:
            return -1
118
+
119
# Worker for parallel training: train one model copy and hand it back.
def train_model(model, input_texts, target_texts, epochs, model_id, rewards):
    """Train `model` for `epochs` epochs; return (model_id, trained model).

    NOTE(review): when run inside a multiprocessing worker, the write to
    `rewards` mutates only the child process's copy of the dict — the
    parent's `rewards` is never updated, so the per-epoch reward history
    is effectively lost. Consider returning `translator.rewards` too.
    """
    translator = ReinforcementTranslator(model, input_token_index, target_token_index)
    translator.train(input_texts, target_texts, epochs)
    rewards[model_id] = translator.rewards
    return model_id, translator.model
125
+
126
if __name__ == '__main__':
    # Initialize nine independent copies of the model.
    # NOTE(review): clone_model copies the architecture only — weights
    # are re-initialized and the clones are not compiled, so each worker
    # would need its own compile step before fit(). Confirm intent.
    models = [tf.keras.models.clone_model(model) for _ in range(9)]
    rewards = {i: [] for i in range(9)}

    # Parallel training across nine processes.
    # NOTE(review): Keras models are generally not picklable, so passing
    # them through Pool.starmap is expected to fail; additionally, any
    # mutation the workers make to `rewards` does not propagate back to
    # this process (see train_model).
    with multiprocessing.Pool(processes=9) as pool:
        results = pool.starmap(train_model, [(model, input_texts, target_texts, 100, i, rewards) for i, model in enumerate(models)])

    # Persist each trained model to disk.
    for model_id, trained_model in results:
        trained_model.save(f'model_{model_id}.h5')

    # Gradio interface.
    def process_file(input_text):
        # One-hot encode the input word.
        # NOTE(review): raises KeyError for any word outside the 9-word
        # training vocabulary — consider a graceful fallback message.
        input_seq = np.zeros((1, num_encoder_tokens))
        input_seq[0, input_token_index[input_text]] = 1.

        # Pick the "best" model.
        # NOTE(review): this always loads model_0 regardless of rewards —
        # no actual best-model selection happens here.
        best_model = tf.keras.models.load_model('model_0.h5')
        translator = ReinforcementTranslator(best_model, input_token_index, target_token_index)
        translation = translator.translate(input_seq)
        return translation.strip()

    def get_rewards():
        # Per-model reward history. NOTE(review): these lists stay empty
        # unless the multiprocessing issue above is resolved.
        return {f'Model {i}': rewards[i] for i in range(9)}

    def stop_gradio():
        # Stop the server by sending SIGINT to this very process.
        os.kill(os.getpid(), signal.SIGINT)
        return "Server stopping..."

    iface = gr.Interface(
        fn=process_file,
        inputs="text",
        outputs="text",
        title="英語単語の翻訳",
        description="このインターフェースでは、英語の単語を入力し、その日本語翻訳を生成します。"
    )

    # NOTE(review): gr.Button / gr.JSON created outside a gr.Blocks
    # context are never rendered, and `outputs="text"` below is not a
    # valid component reference for .click() — these handlers will not be
    # reachable from the UI. They would need to live inside a gr.Blocks.
    rewards_button = gr.Button("Get Rewards")
    rewards_output = gr.JSON()
    rewards_button.click(fn=get_rewards, inputs=[], outputs=rewards_output)

    stop_button = gr.Button("Stop Server")
    stop_button.click(fn=stop_gradio, inputs=[], outputs="text")

    # Launch the web interface.
    iface.launch()