itriedcoding commited on
Commit
67c81b2
·
verified ·
1 Parent(s): bb8b351

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +69 -14
app.py CHANGED
@@ -1,22 +1,77 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
 
 
 
 
4
 
5
- model_name = "itriedcoding/Sage"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def generate_text(prompt, max_length, temperature):
10
- inputs = tokenizer.encode(prompt, return_tensors="pt")
 
 
11
  with torch.no_grad():
12
- outputs = model.generate(
13
- inputs,
14
- max_length=int(max_length),
15
- temperature=temperature,
16
- do_sample=True,
17
- pad_token_id=tokenizer.eos_token_id
18
- )
19
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
20
 
21
  demo = gr.Interface(
22
  fn=generate_text,
@@ -27,7 +82,7 @@ demo = gr.Interface(
27
  ],
28
  outputs=gr.Textbox(label="Generated Text"),
29
  title="Sage Text Generator",
30
- description="Generate text using the Sage custom character-level language model. Built from scratch with PyTorch.",
31
  examples=[
32
  ["Hello", 30, 0.8],
33
  ["The weather", 30, 0.8],
 
1
  import gradio as gr
 
2
  import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import pickle
6
+ import json
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ REPO_ID = "itriedcoding/Sage"
10
+
11
+ # Custom model class matching Sage architecture
12
+ class TransformerLM(nn.Module):
13
+ def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
14
+ super().__init__()
15
+ self.embedding = nn.Embedding(vocab_size, d_model)
16
+ self.pos_embedding = nn.Embedding(max_seq_length, d_model)
17
+ encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, dropout=0.1)
18
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
19
+ self.output_layer = nn.Linear(d_model, vocab_size)
20
+ self.max_seq_length = max_seq_length
21
+ self.vocab_size = vocab_size
22
+
23
+ def forward(self, src):
24
+ seq_len = src.size(1)
25
+ pos = torch.arange(0, seq_len, device=src.device).unsqueeze(0)
26
+ src_emb = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
27
+ pos_emb = self.pos_embedding(pos)
28
+ src_emb = src_emb + pos_emb
29
+ output = self.transformer_encoder(src_emb)
30
+ logits = self.output_layer(output)
31
+ return logits
32
+
33
+ # Download model files from Hugging Face
34
+ print("Downloading model files...")
35
+ config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
36
+ state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
37
+ tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")
38
+
39
+ # Load config
40
+ with open(config_path) as f:
41
+ config = json.load(f)
42
+
43
+ # Load tokenizer
44
+ with open(tok_path, 'rb') as f:
45
+ tokenizer = pickle.load(f)
46
+
47
+ # Load model
48
+ model = TransformerLM(
49
+ vocab_size=config['vocab_size'],
50
+ d_model=config['hidden_size'],
51
+ nhead=config['num_attention_heads'],
52
+ num_layers=config['num_hidden_layers'],
53
+ dim_feedforward=config['intermediate_size'],
54
+ max_seq_length=config['max_position_embeddings']
55
+ )
56
+ state_dict = torch.load(state_path, map_location='cpu', weights_only=True)
57
+ model.load_state_dict(state_dict, strict=False)
58
+ model.eval()
59
 
60
  def generate_text(prompt, max_length, temperature):
61
+ input_ids = tokenizer.encode(prompt, max_length=32, padding=False, truncation=False, return_tensors='pt')
62
+ generated = input_ids.clone()
63
+
64
  with torch.no_grad():
65
+ for _ in range(int(max_length)):
66
+ logits = model(generated)
67
+ next_logits = logits[0, -1, :] / temperature
68
+ probs = torch.softmax(next_logits, dim=-1)
69
+ next_token = torch.multinomial(probs, num_samples=1)
70
+ generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
71
+ if next_token.item() == tokenizer.char_to_idx.get('.', 0):
72
+ break
73
+
74
+ return tokenizer.decode(generated[0])
75
 
76
  demo = gr.Interface(
77
  fn=generate_text,
 
82
  ],
83
  outputs=gr.Textbox(label="Generated Text"),
84
  title="Sage Text Generator",
85
+ description="Custom character-level language model built from scratch with PyTorch.",
86
  examples=[
87
  ["Hello", 30, 0.8],
88
  ["The weather", 30, 0.8],