Update código de modelo
Browse files- código de modelo +4 -4
código de modelo
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
import torch
|
| 2 |
-
from transformers import
|
| 3 |
|
| 4 |
class TextToImageGenerator(torch.nn.Module):
|
| 5 |
-
def __init__(self, model_name="
|
| 6 |
super(TextToImageGenerator, self).__init__()
|
| 7 |
-
self.tokenizer =
|
| 8 |
self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
|
| 9 |
|
| 10 |
def forward(self, input_text):
|
| 11 |
input_ids = self.tokenizer(input_text, return_tensors="pt")["input_ids"]
|
| 12 |
-
output = self.
|
| 13 |
return output.logits
|
| 14 |
|
| 15 |
# Instanciar el modelo
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from transformers import GPT3LMHeadModel, GPT3Tokenizer
|
| 3 |
|
| 4 |
class TextToImageGenerator(torch.nn.Module):
|
| 5 |
+
def __init__(self, model_name="gpt3"):
|
| 6 |
super(TextToImageGenerator, self).__init__()
|
| 7 |
+
self.tokenizer = GPT3Tokenizer.from_pretrained(model_name)
|
| 8 |
self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
|
| 9 |
|
| 10 |
def forward(self, input_text):
|
| 11 |
input_ids = self.tokenizer(input_text, return_tensors="pt")["input_ids"]
|
| 12 |
+
output = self.gpt3(input_ids, return_dict=True)
|
| 13 |
return output.logits
|
| 14 |
|
| 15 |
# Instanciar el modelo
|