Cristian283 commited on
Commit
14e0277
·
verified ·
1 Parent(s): 3a959c0

Update código de modelo

Browse files
Files changed (1) hide show
  1. código de modelo +4 -4
código de modelo CHANGED
@@ -1,15 +1,15 @@
1
  import torch
2
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
  class TextToImageGenerator(torch.nn.Module):
5
- def __init__(self, model_name="gpt2"):
6
  super(TextToImageGenerator, self).__init__()
7
- self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
  self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
9
 
10
  def forward(self, input_text):
11
  input_ids = self.tokenizer(input_text, return_tensors="pt")["input_ids"]
12
- output = self.gpt2(input_ids, return_dict=True)
13
  return output.logits
14
 
15
  # Instanciar el modelo
 
1
  import torch
2
+ from transformers import GPT3LMHeadModel, GPT3Tokenizer
3
 
4
  class TextToImageGenerator(torch.nn.Module):
5
+ def __init__(self, model_name="gpt3"):
6
  super(TextToImageGenerator, self).__init__()
7
+ self.tokenizer = GPT3Tokenizer.from_pretrained(model_name)
8
  self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
9
 
10
  def forward(self, input_text):
11
  input_ids = self.tokenizer(input_text, return_tensors="pt")["input_ids"]
12
+ output = self.gpt3(input_ids, return_dict=True)
13
  return output.logits
14
 
15
  # Instanciar el modelo