Monimoy committed on
Commit
ab1e906
·
verified ·
1 Parent(s): 526bed7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -32
app.py CHANGED
@@ -10,6 +10,8 @@ import os
10
 
11
  # Load the model from Hugging Face Hub
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
13
 
14
  # Define the SmolLM2-135M model (a simplified version of a Transformer)
15
  class SmolLM(nn.Module):
@@ -64,38 +66,13 @@ model = load_model()
64
  model.train(False)
65
 
66
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
67
- enc = tiktoken.get_encoding('gpt2')
68
- tokens = enc.encode(prompt)
69
- tokens = torch.tensor(tokens, dtype=torch.long)
70
- tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
71
- tokens = tokens.to(device)
72
-
73
- with torch.no_grad():
74
- for _ in range(max_length):
75
- if tokens.size(1) >= 1024: # GPT context length
76
- break
77
-
78
- logits = model(tokens)[0]
79
- logits = logits[:, -1, :]
80
- #logits = logits[:, -1, :] / temperature
81
- probs = F.softmax(logits, dim=-1)
82
-
83
- # Top-k sampling
84
- topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
85
- ix = torch.multinomial(topk_probs, 1)
86
- next_token = torch.gather(topk_indices, -1, ix)
87
-
88
- tokens = torch.cat((tokens, next_token), dim=1)
89
-
90
- # Remove special token check entirely
91
- # Just generate for the specified length or until context limit
92
-
93
- generated_texts = []
94
- for i in range(num_samples):
95
- text = enc.decode(tokens[i].tolist())
96
- generated_texts.append(text)
97
-
98
- return '\n\n---\n\n'.join(generated_texts)
99
 
100
  # Create Gradio interface
101
  iface = gr.Interface(
 
10
 
11
  # Load the model from Hugging Face Hub
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
14
+ tokenizer.pad_token = tokenizer.eos_token
15
 
16
  # Define the SmolLM2-135M model (a simplified version of a Transformer)
17
  class SmolLM(nn.Module):
 
66
  model.train(False)
67
 
68
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
69
+
70
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
71
+ outputs = model(input_ids)
72
+ predictions = torch.argmax(outputs, dim=-1)
73
+ decoded = tokenizer.decode(predictions[0], skip_special_tokens=True)
74
+ return decoded
75
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Create Gradio interface
78
  iface = gr.Interface(