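# Streamlit app: generates Shakespeare-style text with a small decoder-only
# transformer. The fitted TextVectorization pipeline and trained weights are
# loaded from the src/ directory.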
import re
import string

import numpy as np
import streamlit as st
import tensorflow as tf

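# Model hyperparameters; these must match the configuration used when the
# weights in src/Shakespeare_decoder.weights (1).h5 were trained.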
BATCH_SIZE = 128
NUM_HEADS = 12
NUM_BLOCKS = 2
EMBED_DIM = 384
DENSE_DIM = 1536
DROPOUT_RATE = 0.3
CHUNK_LENGTH = 256
vocab_size = 12050
sequence_length = CHUNK_LENGTH+1

# strip_chars is referenced below but was never defined in this file;
# string.punctuation is assumed and must match the vectorizer's training setup.
strip_chars = string.punctuation

def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(
        lowercase, f"[{re.escape(strip_chars)}]", "")

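# Recover the fitted TextVectorization layer from the saved vectorizer model;
# its vocabulary is reused below to map sampled token ids back to words.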
vectorized_model = tf.keras.models.load_model(
    "src/ShakespeareVect.keras",
    custom_objects={"custom_standardization": custom_standardization}
)
vectorizer = vectorized_model.layers[0]

class PositionalEmbedding(tf.keras.layers.Layer):
    """Sums learned token embeddings with learned position embeddings."""
    def __init__(self, sequence_length, vocab_size, output_dim):
        super().__init__()
        self.positional_embedding = tf.keras.layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim, mask_zero=False)
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=output_dim, mask_zero=True)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size

    def call(self, inputs):
        # One position embedding per timestep in the current sequence.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embedding(inputs)
        embedded_positions = self.positional_embedding(positions)
        return embedded_tokens + embedded_positions

class TransformerDecoder(tf.keras.layers.Layer):
    """Decoder block: causal self-attention followed by a position-wise MLP,
    each with a residual connection, layer normalization, and dropout."""
    def __init__(self, num_heads, embed_dim, dense_dim, dropout_rate):
        super().__init__()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dense_proj = tf.keras.models.Sequential([
            tf.keras.layers.Dense(dense_dim, activation='gelu'),
            tf.keras.layers.Dense(embed_dim)
        ])
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs):
        # The causal mask keeps each position from attending to later tokens.
        attn_out = self.attention(query=inputs, key=inputs, value=inputs,
                                  use_causal_mask=True)
        norm1_out = self.layernorm1(attn_out + inputs)
        drop1_out = self.dropout1(norm1_out)
        dense_proj_out = self.dense_proj(drop1_out)
        norm2_out = self.layernorm2(drop1_out + dense_proj_out)
        drop2_out = self.dropout2(norm2_out)
        return drop2_out

# Assemble the decoder-only model: embeddings, NUM_BLOCKS decoder blocks,
# dropout, and a final projection to vocabulary logits.
inputs = tf.keras.layers.Input(shape=(None,))
embeddings = PositionalEmbedding(sequence_length, vocab_size, EMBED_DIM)(inputs)
x = embeddings
for _ in range(NUM_BLOCKS):
    x = TransformerDecoder(NUM_HEADS, EMBED_DIM, DENSE_DIM, DROPOUT_RATE)(x)
x = tf.keras.layers.Dropout(0.3)(x)
output = tf.keras.layers.Dense(vocab_size, activation='linear',
                               kernel_initializer='glorot_uniform')(x)
transformer = tf.keras.models.Model(inputs, output)

transformer.load_weights('src/Shakespeare_decoder.weights (1).h5')

def generate_text(prompt, max_length=50, temperature=1.0):
    """Autoregressive sampler: append one sampled word per step, up to max_length words."""
    for _ in range(max_length):
        tokenized = vectorizer([prompt])
        tokenized_np = tokenized.numpy()[0]

        # Find the position of the last non-padding token
        last_idx = np.max(np.nonzero(tokenized_np))

        preds = transformer(tokenized, training=False)
        logits = preds[0, last_idx, :].numpy()

        # Temperature sampling via a numerically stable softmax
        scaled = logits / temperature
        scaled -= np.max(scaled)
        probs = np.exp(scaled) / np.sum(np.exp(scaled))
        next_id = np.random.choice(len(probs), p=probs)

        next_word = vectorizer.get_vocabulary()[next_id]
        if next_word in ("", "[UNK]"):
            break

        prompt += " " + next_word
    return prompt


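# Streamlit front end: take a prompt and display the generated continuation.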
user_input = st.text_input("Enter some text:", "")

if user_input:
    with st.spinner("Generating Text..."):
        text = generate_text(user_input)
        st.text(text)
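
# Run locally with Streamlit's CLI, pointing at this file (path is assumed):
#   streamlit run <path/to/this_file.py>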