Spaces:
Running
Running
| import tensorflow as tf | |
| from tensorflow.keras.layers import TextVectorization, Embedding, MultiHeadAttention, LayerNormalization, Dense, Dropout | |
| from tensorflow.keras.models import Model | |
| import gradio as gr | |
| import json | |
# Sentinel tokens placed around each prompt before vectorization; the decoded
# reply is cut at the first END_TOKEN.
START_TOKEN, END_TOKEN = "<start>", "<end>"
class TransformerBlock(tf.keras.layers.Layer):
    """Post-norm Transformer block: self-attention followed by a feed-forward net.

    Each sub-layer output goes through dropout, is added back to its input
    (residual connection) and is then layer-normalised. No explicit attention
    mask is passed, so attention is not restricted by position.

    Args:
        embed_dim: Size of the feature axis of the input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after attention and after the feed-forward net.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments, kept so get_config() can rebuild the layer.
        self.embed_dim, self.num_heads, self.ff_dim, self.rate = (
            embed_dim, num_heads, ff_dim, rate)
        # NOTE(review): weights are restored from an .h5 file; Keras relates
        # saved weights to sub-layers by tracking order, so keep these
        # assignments in this order.
        # Each head uses key_dim=embed_dim (not embed_dim // num_heads);
        # changing it would change the projection weight shapes.
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # Self-attention sub-layer: the input serves as both query and value.
        attended = self.dropout1(self.att(inputs, inputs), training=training)
        hidden = self.layernorm1(inputs + attended)
        # Feed-forward sub-layer with its own residual + normalisation.
        transformed = self.dropout2(self.ffn(hidden), training=training)
        return self.layernorm2(hidden + transformed)

    def get_config(self):
        """Return the base layer config extended with this block's arguments."""
        return {
            **super().get_config(),
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        }
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned absolute-position embedding to a learned token embedding.

    Args:
        maxlen: Number of rows in the position table; positions at or beyond
            this index have no embedding.
        vocab_size: Number of rows in the token table (token ids must be below it).
        embed_dim: Length of every embedding vector.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # NOTE(review): keep token_emb before pos_emb; weights are restored
        # from .h5 and Keras relates them to sub-layers by tracking order.
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Positions 0 .. seq_len-1 along the last axis of the token-id tensor.
        seq_len = tf.shape(x)[-1]
        position_vectors = self.pos_emb(tf.range(seq_len))
        token_vectors = self.token_emb(x)
        # position_vectors is (seq_len, embed_dim) and broadcasts over the
        # leading axes of token_vectors.
        return token_vectors + position_vectors

    def get_config(self):
        """Return the base layer config extended with this layer's arguments."""
        return {
            **super().get_config(),
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        }
def load_model(filename="tg-medium"):
    """Load the saved Keras model and build a matching text vectorizer.

    Reads ``<filename>.h5`` (with this file's two custom layers registered)
    and the vocabulary list stored in ``<filename>.json``.

    Args:
        filename: Path prefix shared by the ``.h5`` model and ``.json`` vocabulary.

    Returns:
        ``(model, vectorizer)``. The vectorizer performs no standardisation
        and pads/truncates every input to 100 token ids.
    """
    layer_registry = {
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock,
    }
    keras_model = tf.keras.models.load_model(
        f'{filename}.h5', custom_objects=layer_registry)

    # Presumably the list written from get_vocabulary() at training time — TODO confirm.
    with open(f'{filename}.json', encoding='utf-8') as vocab_file:
        vocabulary = json.load(vocab_file)

    text_vectorizer = TextVectorization(
        max_tokens=128000,
        output_sequence_length=100,
        standardize=None,
        vocabulary=vocabulary,
    )
    return keras_model, text_vectorizer
def generate_text(model, vectorizer, prompt):
    """Run one forward pass of *model* on *prompt* and decode the argmax tokens.

    The prompt is wrapped in START_TOKEN / END_TOKEN, vectorized, and its last
    position dropped before being fed to the model. The most likely token at
    every output position is taken (a single, non-autoregressive pass). The
    sequence is cut at the first END_TOKEN, and padding ('') and START_TOKEN
    entries are removed.

    Args:
        model: Keras model returning per-position scores over the vocabulary
            (assumed from the argmax over the last axis — TODO confirm shape).
        vectorizer: TextVectorization layer whose vocabulary matches *model*.
        prompt: User text to condition on.

    Returns:
        The decoded tokens joined by single spaces.
    """
    wrapped = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    # Drop the final position so the input length matches what the model
    # expects (presumably output_sequence_length - 1 — TODO confirm).
    input_seq = vectorizer([wrapped])[:, :-1]

    # A direct call is cheaper than predict() for a single example and avoids
    # predict()'s per-call progress-bar output.
    predictions = model(input_seq, training=False)
    predicted_ids = tf.argmax(predictions[0], axis=-1).numpy()

    vocab = vectorizer.get_vocabulary()
    output_tokens = [vocab[idx] for idx in predicted_ids]
    if END_TOKEN in output_tokens:
        output_tokens = output_tokens[:output_tokens.index(END_TOKEN)]

    # get_vocabulary() puts the padding token '' at index 0. Joining it would
    # leave runs of stray spaces, so drop it along with every START_TOKEN
    # (list.remove would only drop the first one).
    return ' '.join(tok for tok in output_tokens if tok and tok != START_TOKEN)
def main():
    """Load the model once and serve it through a Gradio text interface."""
    model, vectorizer = load_model()

    def generate_response(prompt):
        # Gradio callback: reuse the model/vectorizer loaded above.
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        # Fixed typo: the page serves model *inference*, not "interference".
        description="Inference API (Russian only).",
    )
    iface.launch()


if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    main()