# veda-programming / model.py
# Source: Hugging Face repo "vedaco/veda-programming", commit 54392ea (verified).
# (Page-scrape header — "raw / history / blame", 5.98 kB — kept as comments
# so this file remains valid Python.)
"""Veda Programming Assistant Model"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
class VedaProgrammingLLM(keras.Model):
    """Decoder-only transformer language model for programming assistance.

    Architecture (post-LayerNorm / "post-LN" residual blocks):
        token embedding (scaled by sqrt(d_model)) + learned positional
        embedding -> dropout -> ``num_layers`` x [causal multi-head
        self-attention + residual + LayerNorm, GELU feed-forward +
        residual + LayerNorm] -> final LayerNorm -> linear head
        producing un-normalized vocabulary logits.
    """

    def __init__(
        self,
        vocab_size: int,
        max_length: int = 512,
        d_model: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        ff_dim: int = 512,
        **kwargs
    ):
        """Build the model.

        Args:
            vocab_size: Size of the token vocabulary (output logit width).
            max_length: Maximum sequence length; also the size of the learned
                positional-embedding table, so inputs longer than this will
                raise at the position lookup.
            d_model: Embedding / hidden width of the transformer.
            num_heads: Attention heads per layer (``d_model`` should be
                divisible by ``num_heads``; ``key_dim`` is the per-head width).
            num_layers: Number of transformer blocks.
            ff_dim: Inner width of the position-wise feed-forward network.
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim

        self.token_embedding = layers.Embedding(vocab_size, d_model)
        # Learned (not sinusoidal) positional embeddings, one row per position.
        self.pos_embedding = layers.Embedding(max_length, d_model)
        self.dropout = layers.Dropout(0.1)

        # Per-layer sublayers kept in parallel lists, indexed together in call().
        self.attn_layers = []
        self.ffn_layers = []
        self.ln1_layers = []
        self.ln2_layers = []
        for _ in range(num_layers):
            self.attn_layers.append(
                layers.MultiHeadAttention(
                    num_heads=num_heads,
                    key_dim=d_model // num_heads,
                    dropout=0.1
                )
            )
            self.ffn_layers.append(
                keras.Sequential([
                    layers.Dense(ff_dim, activation='gelu'),
                    layers.Dropout(0.1),
                    layers.Dense(d_model),
                    layers.Dropout(0.1)
                ])
            )
            self.ln1_layers.append(layers.LayerNormalization(epsilon=1e-6))
            self.ln2_layers.append(layers.LayerNormalization(epsilon=1e-6))
        self.final_ln = layers.LayerNormalization(epsilon=1e-6)
        # Linear head; no softmax — returns raw logits.
        self.output_layer = layers.Dense(vocab_size)

    def call(self, inputs, training=False):
        """Forward pass.

        Args:
            inputs: int token-id tensor of shape ``(batch, seq_len)`` with
                ``seq_len <= max_length``.
            training: Enables dropout when True.

        Returns:
            Float tensor of logits with shape ``(batch, seq_len, vocab_size)``.
        """
        seq_len = tf.shape(inputs)[1]
        # Lower-triangular (causal) mask: position i may attend to j <= i.
        # MultiHeadAttention treats 1 as "attend", 0 as "masked".
        mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        positions = tf.range(seq_len)
        x = self.token_embedding(inputs)
        # Standard transformer embedding scaling.
        x = x * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_embedding(positions)
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            attn_out = self.attn_layers[i](x, x, attention_mask=mask, training=training)
            x = self.ln1_layers[i](x + attn_out)  # post-LN residual
            ffn_out = self.ffn_layers[i](x, training=training)
            x = self.ln2_layers[i](x + ffn_out)
        x = self.final_ln(x)
        return self.output_layer(x)

    def generate(
        self,
        prompt_tokens: list,
        max_new_tokens: int = 200,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        stop_tokens: list | None = None
    ) -> list:
        """Autoregressively sample a continuation of ``prompt_tokens``.

        Each step re-runs the full forward pass on the last ``max_length``
        tokens (no KV cache), then applies, in order: repetition penalty
        over the last 100 generated tokens, temperature scaling, top-k
        filtering, and nucleus (top-p) filtering, before sampling from the
        resulting distribution.

        Args:
            prompt_tokens: Token ids to condition on (not modified in place).
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature; clamped to at least 0.1.
            top_k: Keep only the k highest logits (0 disables).
            top_p: Keep the smallest prefix of sorted tokens whose cumulative
                probability exceeds ``top_p`` (1.0 disables).
            repetition_penalty: >1.0 discourages recently-emitted tokens
                (divides positive logits, multiplies negative ones — the
                CTRL-paper formulation).
            stop_tokens: Optional extra token ids that end generation.

        Returns:
            ``prompt_tokens`` followed by the newly generated token ids.
        """
        generated = list(prompt_tokens)
        for _ in range(max_new_tokens):
            # Window the context so positions stay within the embedding table.
            context = generated[-self.max_length:]
            input_tensor = tf.constant([context], dtype=tf.int32)
            logits = self(input_tensor, training=False)
            # Work in float64 on the last position's logits for stable math.
            next_logits = logits[0, -1, :].numpy().astype(np.float64)
            if repetition_penalty != 1.0:
                for token_id in set(generated[-100:]):
                    if 0 <= token_id < len(next_logits):
                        if next_logits[token_id] > 0:
                            next_logits[token_id] /= repetition_penalty
                        else:
                            next_logits[token_id] *= repetition_penalty
            next_logits = next_logits / max(temperature, 0.1)
            if top_k > 0 and top_k < len(next_logits):
                # Threshold at the k-th largest logit; ties survive.
                indices_to_remove = next_logits < np.partition(next_logits, -top_k)[-top_k]
                next_logits[indices_to_remove] = -np.inf
            if top_p < 1.0:
                sorted_indices = np.argsort(next_logits)[::-1]
                sorted_logits = next_logits[sorted_indices]
                # Guard against an all -inf row before the stabilizing subtraction.
                max_logit = np.max(sorted_logits[sorted_logits > -np.inf]) if np.any(sorted_logits > -np.inf) else 0
                exp_logits = np.exp(sorted_logits - max_logit)
                probs = exp_logits / (np.sum(exp_logits) + 1e-10)
                cumulative = np.cumsum(probs)
                remove_mask = cumulative > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the single most likely token.
                remove_mask[1:] = remove_mask[:-1].copy()
                remove_mask[0] = False
                next_logits[sorted_indices[remove_mask]] = -np.inf
            # Numerically-stable softmax with masked entries forced to 0.
            max_logit = np.max(next_logits[next_logits > -np.inf]) if np.any(next_logits > -np.inf) else 0
            exp_logits = np.exp(next_logits - max_logit)
            exp_logits[next_logits == -np.inf] = 0
            probs = exp_logits / (np.sum(exp_logits) + 1e-10)
            probs = np.clip(probs, 0, 1)
            prob_sum = np.sum(probs)
            if prob_sum > 0:
                probs = probs / prob_sum
            else:
                # Degenerate case (everything filtered): fall back to uniform.
                probs = np.ones_like(probs) / len(probs)
            try:
                next_token = np.random.choice(len(probs), p=probs)
            except ValueError:
                # np.random.choice rejects probs that don't sum to ~1; fall
                # back to the argmax rather than crashing mid-generation.
                next_token = np.argmax(probs)
            generated.append(int(next_token))
            # NOTE(review): ids 0 and 3 are hard-coded end-of-sequence markers —
            # presumably PAD/EOS in the paired tokenizer; confirm against it.
            if next_token == 0 or next_token == 3:
                break
            if stop_tokens and next_token in stop_tokens:
                break
        return generated

    def get_config(self):
        """Return the serialization config for this model.

        Chains to ``super().get_config()`` (the documented Keras pattern) so
        base ``keras.Model`` settings such as ``name`` survive a save/load
        round-trip; the previous version dropped them.
        """
        config = super().get_config()
        config.update({
            'vocab_size': self.vocab_size,
            'max_length': self.max_length,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'ff_dim': self.ff_dim
        })
        return config