Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +92 -39
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,93 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import tensorflow as tf
|
| 3 |
import numpy as np
|
| 4 |
+
|
| 5 |
+
def custom_standardization(input_string):
    """Standardize raw text for the TextVectorization layer.

    Lowercases the input and removes every character listed in the
    module-level ``strip_chars`` string.

    NOTE(review): ``strip_chars`` is not defined in this view — confirm it
    is assigned (with the same value used at training time) before the
    vectorizer model is loaded below.
    """
    lowered = tf.strings.lower(input_string)
    pattern = f"[{re.escape(strip_chars)}]"
    return tf.strings.regex_replace(lowered, pattern, "")
|
| 9 |
+
|
| 10 |
+
# Restore the fitted TextVectorization layer from its saved wrapper model.
# custom_objects registers the standardization function so Keras can
# deserialize the layer's config.
vectorized_model = tf.keras.models.load_model(
    "ShakespeareVect.keras",
    custom_objects={"custom_standardization": custom_standardization},
)
# The wrapper's first layer is the TextVectorization layer itself.
vectorizer = vectorized_model.layers[0]
|
| 15 |
+
|
| 16 |
+
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding summed with a learned absolute position embedding."""

    def __init__(self, sequence_length, vocab_size, output_dim):
        super().__init__()
        # Token ids -> vectors; mask_zero=True so padding (id 0) is masked
        # and the mask propagates to downstream layers.
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=output_dim, mask_zero=True
        )
        # Position indices -> vectors; positions are never "padding",
        # so no masking here.
        self.positional_embedding = tf.keras.layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim, mask_zero=False
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size

    def call(self, inputs):
        # inputs: integer token ids with shape (..., seq_len).
        seq_len = tf.shape(inputs)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)
        embedded_tokens = self.token_embedding(inputs)
        embedded_positions = self.positional_embedding(position_ids)
        # Broadcast position embeddings across the batch dimension.
        return embedded_tokens + embedded_positions
|
| 29 |
+
|
| 30 |
+
class TransformerDecoder(tf.keras.layers.Layer):
    """One decoder block: causal self-attention followed by a position-wise MLP.

    NOTE(review): the normalization/dropout order here is slightly unusual —
    dropout is applied *after* each LayerNormalization, and the MLP residual
    reuses the dropped-out attention branch. This is reproduced exactly, as
    the saved weights were trained with this layout.
    """

    def __init__(self, num_heads, embed_dim, dense_dim, dropout_rate):
        super().__init__()
        # Per-head key dim keeps total attention width equal to embed_dim.
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        # Two-layer feed-forward projection: expand to dense_dim, then
        # project back to embed_dim.
        self.dense_proj = tf.keras.models.Sequential(
            [
                tf.keras.layers.Dense(dense_dim, activation="gelu"),
                tf.keras.layers.Dense(embed_dim),
            ]
        )
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs):
        # Causal mask prevents each position from attending to later tokens.
        attended = self.attention(
            query=inputs, key=inputs, value=inputs, use_causal_mask=True
        )
        branch = self.dropout1(self.layernorm1(attended + inputs))
        projected = self.dense_proj(branch)
        return self.dropout2(self.layernorm2(branch + projected))
|
| 54 |
+
|
| 55 |
+
# Rebuild the decoder-only language model and load the trained weights.
# NOTE(review): sequence_length, vocab_size, EMBED_DIM, NUM_BLOCKS,
# NUM_HEADS, DENSE_DIM and DROPOUT_RATE are not defined in this view —
# confirm they are set to the training-time values before this point.
inputs = tf.keras.layers.Input(shape=(None,))
x = PositionalEmbedding(sequence_length, vocab_size, EMBED_DIM)(inputs)
for _ in range(NUM_BLOCKS):
    x = TransformerDecoder(NUM_HEADS, EMBED_DIM, DENSE_DIM, DROPOUT_RATE)(x)
x = tf.keras.layers.Dropout(0.3)(x)
# Linear head over the vocabulary: raw logits, sampled with a softmax later.
output = tf.keras.layers.Dense(
    vocab_size, activation="linear", kernel_initializer="glorot_uniform"
)(x)
transformer = tf.keras.models.Model(inputs, output)

# Weight file must match the architecture built above exactly.
transformer.load_weights("Shakespeare_decoder.weights (1).h5")
|
| 65 |
+
def generate_text(prompt, max_length=50, temperature=1.0):
    """Autoregressively extend *prompt* by up to *max_length* sampled words.

    Args:
        prompt: Seed string; sampled words are appended to it.
        max_length: Maximum number of words to generate.
        temperature: Softmax temperature (> 0); values > 1 flatten the
            sampling distribution, values < 1 sharpen it.

    Returns:
        The prompt with the generated continuation appended. Generation
        stops early when the sampled token is padding ("") or "[UNK]", or
        when the prompt vectorizes to nothing but padding.
    """
    # Hoisted out of the loop: the vocabulary is invariant across steps.
    vocab = vectorizer.get_vocabulary()

    for _ in range(max_length):
        tokenized = vectorizer([prompt])
        tokenized_np = tokenized.numpy()[0]

        # Find the last non-padding token. Guard against an all-padding
        # tokenization (e.g. empty/all-OOV prompt): np.max on an empty
        # array would raise ValueError.
        nonzero = np.nonzero(tokenized_np)[0]
        if nonzero.size == 0:
            break
        last_idx = int(nonzero.max())

        preds = transformer(tokenized, training=False)
        logits = preds[0, last_idx, :].numpy().astype(np.float64)

        # Numerically stable temperature softmax: subtracting the max
        # leaves the distribution unchanged but prevents np.exp overflow
        # for large logits / small temperatures. float64 + explicit
        # renormalization keeps np.random.choice from rejecting `p`
        # over tiny rounding error.
        scaled = logits / temperature
        scaled -= scaled.max()
        exp_scaled = np.exp(scaled)
        probs = exp_scaled / exp_scaled.sum()
        next_id = np.random.choice(len(probs), p=probs)

        next_word = vocab[next_id]
        if next_word in ("", "[UNK]"):
            break

        prompt += " " + next_word
    return prompt
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# --- Streamlit UI ---
# NOTE(review): `st` is used below but `import streamlit as st` is not
# visible in this view — confirm it is present at the top of the file.
user_input = st.text_input("Enter some text:", "")

if user_input != "":
    # Show a spinner while the (potentially slow) generation loop runs.
    with st.spinner("Generating Text..."):
        text = generate_text(user_input)
    st.text(text)
|