# veda-programming / model.py
# Author: vedaco
# Last commit: "Update model.py"
# Revision: 54392ea (verified)
"""Veda Programming Assistant Model"""
from typing import Iterable, Optional

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class VedaProgrammingLLM(keras.Model):
    """Conversational programming-assistant language model.

    A decoder-only Transformer: token embeddings plus learned positional
    embeddings, ``num_layers`` post-LayerNorm blocks (causal multi-head
    self-attention followed by a GELU feed-forward network), a final
    LayerNorm, and a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token ids; size of the token-embedding table
            and of the output logits axis.
        max_length: Size of the positional-embedding table, i.e. the longest
            sequence ``call`` accepts. ``generate`` feeds at most this many
            trailing tokens per step.
        d_model: Embedding / hidden width.
        num_heads: Attention heads per layer; each head uses
            ``d_model // num_heads`` key dimensions.
        num_layers: Number of Transformer blocks.
        ff_dim: Hidden width of each feed-forward sub-layer.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        vocab_size: int,
        max_length: int = 512,
        d_model: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim

        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.pos_embedding = layers.Embedding(max_length, d_model)
        self.dropout = layers.Dropout(0.1)

        # NOTE: layers are created interleaved per block (attn, ffn, ln1, ln2).
        # Keep this order and these attribute names: they determine the
        # auto-generated layer names and the tracked weight structure, which
        # saved weights/checkpoints depend on.
        self.attn_layers = []
        self.ffn_layers = []
        self.ln1_layers = []
        self.ln2_layers = []
        for _ in range(num_layers):
            self.attn_layers.append(
                layers.MultiHeadAttention(
                    num_heads=num_heads,
                    key_dim=d_model // num_heads,
                    dropout=0.1
                )
            )
            self.ffn_layers.append(
                keras.Sequential([
                    layers.Dense(ff_dim, activation='gelu'),
                    layers.Dropout(0.1),
                    layers.Dense(d_model),
                    layers.Dropout(0.1)
                ])
            )
            self.ln1_layers.append(layers.LayerNormalization(epsilon=1e-6))
            self.ln2_layers.append(layers.LayerNormalization(epsilon=1e-6))

        self.final_ln = layers.LayerNormalization(epsilon=1e-6)
        # No activation: the model emits raw (unnormalized) logits.
        self.output_layer = layers.Dense(vocab_size)

    def call(self, inputs, training=False):
        """Return next-token logits for every position of ``inputs``.

        Args:
            inputs: Integer token ids, presumably shaped (batch, seq_len)
                given that axis 1 is read as the sequence length below
                (TODO confirm for all callers). ``seq_len`` must not exceed
                ``max_length``.
            training: Enables dropout when True.

        Returns:
            Raw logits whose last axis has size ``vocab_size``.

        Raises:
            tf.errors.InvalidArgumentError: If the sequence is longer than
                ``max_length``.
        """
        seq_len = tf.shape(inputs)[1]
        # Positions index a table with ``max_length`` rows. Out-of-range
        # lookups error on CPU but silently yield zero vectors on GPU, so
        # fail loudly instead.
        tf.debugging.assert_less_equal(
            seq_len,
            self.max_length,
            message="input sequence length exceeds max_length",
        )
        # Lower-triangular ones: position i may attend only to positions <= i.
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        positions = tf.range(seq_len)

        x = self.token_embedding(inputs)
        # Cast to the activation dtype so this also works under mixed
        # precision (float16 activations * float32 scalar would fail).
        x = x * tf.math.sqrt(tf.cast(self.d_model, x.dtype))
        x = x + self.pos_embedding(positions)
        x = self.dropout(x, training=training)

        blocks = zip(
            self.attn_layers, self.ffn_layers, self.ln1_layers, self.ln2_layers
        )
        for attn, ffn, ln1, ln2 in blocks:
            # Post-LN residual blocks: normalize after adding each sub-layer.
            attn_out = attn(x, x, attention_mask=causal_mask, training=training)
            x = ln1(x + attn_out)
            ffn_out = ffn(x, training=training)
            x = ln2(x + ffn_out)

        x = self.final_ln(x)
        return self.output_layer(x)

    def generate(
        self,
        prompt_tokens: Iterable[int],
        max_new_tokens: int = 200,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        stop_tokens: Optional[Iterable[int]] = None,
        eos_token_ids: Iterable[int] = (0, 3),
    ) -> list:
        """Autoregressively sample a continuation of ``prompt_tokens``.

        Each step runs a full forward pass over the last ``max_length``
        tokens (no KV cache). It then shapes the final-position logits with
        a repetition penalty, temperature, top-k and nucleus (top-p)
        filtering, and samples with NumPy's global RNG.

        Args:
            prompt_tokens: Non-empty sequence of token ids to condition on.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logit divisor; values below 0.1 are clamped to 0.1.
            top_k: Keep only logits >= the ``top_k``-th largest (ties are
                kept). A non-positive value, or one >= the vocab size,
                disables it.
            top_p: Keep the highest-probability tokens up to and including
                the first at which cumulative probability exceeds ``top_p``.
                ``>= 1.0`` disables it.
            repetition_penalty: Applied to each distinct token among the last
                100 ids (prompt included); ``1.0`` disables it.
            stop_tokens: Extra token ids that end generation.
            eos_token_ids: Token ids that always end generation. The default
                ``(0, 3)`` matches the ids previously hard-coded here.

        Returns:
            The prompt followed by the sampled tokens. When a stop token is
            sampled, it is included as the last element.

        Raises:
            ValueError: If ``prompt_tokens`` is empty.
        """
        generated = list(prompt_tokens)
        if not generated:
            raise ValueError("prompt_tokens must contain at least one token")

        stop_ids = {int(t) for t in eos_token_ids}
        if stop_tokens is not None:
            stop_ids.update(int(t) for t in stop_tokens)

        for _ in range(max_new_tokens):
            context = generated[-self.max_length:]
            input_tensor = tf.constant([context], dtype=tf.int32)
            logits = self(input_tensor, training=False)
            next_logits = logits[0, -1, :].numpy().astype(np.float64)

            next_logits = self._apply_repetition_penalty(
                next_logits, generated[-100:], repetition_penalty
            )
            next_logits = next_logits / max(temperature, 0.1)
            next_logits = self._filter_top_k(next_logits, top_k)
            next_logits = self._filter_top_p(next_logits, top_p)
            next_token = self._sample_from_logits(next_logits)

            generated.append(next_token)
            if next_token in stop_ids:
                break
        return generated

    @staticmethod
    def _apply_repetition_penalty(logits, recent_tokens, penalty):
        """Make tokens in ``recent_tokens`` less likely; mutates ``logits``.

        Positive logits are divided by ``penalty`` and non-positive ones
        multiplied by it, so with ``penalty > 1`` the logit always decreases.
        Out-of-range ids are ignored.
        """
        if penalty == 1.0:
            return logits
        for token_id in set(recent_tokens):
            if 0 <= token_id < len(logits):
                if logits[token_id] > 0:
                    logits[token_id] /= penalty
                else:
                    logits[token_id] *= penalty
        return logits

    @staticmethod
    def _filter_top_k(logits, top_k):
        """Set every logit below the ``top_k``-th largest to -inf (in place)."""
        if 0 < top_k < len(logits):
            kth_largest = np.partition(logits, -top_k)[-top_k]
            logits[logits < kth_largest] = -np.inf
        return logits

    @staticmethod
    def _stable_softmax(logits):
        """Max-shifted softmax that assigns zero probability to -inf logits.

        Returns all zeros if no logit is finite.
        """
        finite = logits > -np.inf
        max_logit = np.max(logits[finite]) if np.any(finite) else 0
        exp_logits = np.exp(logits - max_logit)
        exp_logits[~finite] = 0
        return exp_logits / (np.sum(exp_logits) + 1e-10)

    @classmethod
    def _filter_top_p(cls, logits, top_p):
        """Nucleus filtering: -inf every token outside the top-p set (in place)."""
        if top_p >= 1.0:
            return logits
        order = np.argsort(logits)[::-1]
        cumulative = np.cumsum(cls._stable_softmax(logits[order]))
        remove = cumulative > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        remove[1:] = remove[:-1].copy()
        remove[0] = False
        logits[order[remove]] = -np.inf
        return logits

    @classmethod
    def _sample_from_logits(cls, logits):
        """Draw one token id from softmax(``logits``) using ``np.random``.

        Falls back to a uniform distribution if all probability mass
        vanished, and to argmax if NumPy rejects the distribution.
        """
        probs = np.clip(cls._stable_softmax(logits), 0, 1)
        total = np.sum(probs)
        if total > 0:
            probs = probs / total
        else:
            probs = np.ones_like(probs) / len(probs)
        try:
            return int(np.random.choice(len(probs), p=probs))
        except ValueError:
            return int(np.argmax(probs))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model.

        Only the architecture hyper-parameters are included. Base-layer
        fields such as ``name`` are not merged in.
        """
        return {
            'vocab_size': self.vocab_size,
            'max_length': self.max_length,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'ff_dim': self.ff_dim
        }