Upload app.py
app.py CHANGED
@@ -10,6 +10,8 @@ import os
 
 # Load the model from Hugging Face Hub
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
 
 # Define the SmolLM2-135M model (a simplified version of a Transformer)
 class SmolLM(nn.Module):
@@ -64,38 +66,13 @@ model = load_model()
 model.train(False)
 
 def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
-
-
-
-
-
-
-
-    for _ in range(max_length):
-        if tokens.size(1) >= 1024:  # GPT context length
-            break
-
-        logits = model(tokens)[0]
-        logits = logits[:, -1, :]
-        #logits = logits[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-
-        # Top-k sampling
-        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
-        ix = torch.multinomial(topk_probs, 1)
-        next_token = torch.gather(topk_indices, -1, ix)
-
-        tokens = torch.cat((tokens, next_token), dim=1)
-
-    # Remove special token check entirely
-    # Just generate for the specified length or until context limit
-
-    generated_texts = []
-    for i in range(num_samples):
-        text = enc.decode(tokens[i].tolist())
-        generated_texts.append(text)
-
-    return '\n\n---\n\n'.join(generated_texts)
+
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+    outputs = model(input_ids)
+    predictions = torch.argmax(outputs, dim=-1)
+    decoded = tokenizer.decode(predictions[0], skip_special_tokens=True)
+    return decoded
+
 
 # Create Gradio interface
 iface = gr.Interface(
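The first hunk loads GPT-2's tokenizer and reuses its EOS token as the pad token. GPT-2's tokenizer ships without a pad token, so any batched tokenization with padding enabled fails until one is set; that is presumably why the line is here, since the new generate_text body itself only tokenizes a single prompt. A minimal standalone sketch of what the workaround enables (the two example prompts are illustrative, not from the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # same workaround as in the commit

# Without a pad token this call raises; with it, the shorter prompt is
# padded with EOS up to the length of the longest sequence in the batch.
batch = tokenizer(["Hello", "A longer prompt here"],
                  padding=True, return_tensors="pt")
print(batch.input_ids.shape)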
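The second hunk replaces the autoregressive top-k sampling loop with a single forward pass: the new body takes the argmax over the logits at every position and decodes that, so the max_length, num_samples, and temperature parameters are no longer used. For reference, here is a self-contained sketch of the sampling scheme the removed loop implemented, written against Hugging Face's stock GPT-2 rather than the app's SmolLM class (an assumption made only so the example is runnable as-is):

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Stand-in model for illustration; the app uses its own SmolLM class.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

tokens = tokenizer("Once upon a time", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(100):                    # max_length in the removed code
        if tokens.size(1) >= 1024:          # GPT-2 context length
            break
        logits = model(tokens).logits[:, -1, :]   # logits for the next token only
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # keep top 50
        ix = torch.multinomial(topk_probs, 1)      # sample within the top-k
        next_token = torch.gather(topk_indices, -1, ix)
        tokens = torch.cat((tokens, next_token), dim=1)

print(tokenizer.decode(tokens[0], skip_special_tokens=True))

Note the trade-off: the removed loop appends one sampled token at a time and keeps generating, while the new body predicts one token per input position in a single pass, so it does not continue the prompt token by token.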