---
license: cc-by-nc-sa-4.0
language:
- en
pipeline_tag: text-to-image
---

# Model Card for AlignDRAW

This model card was generated using [this raw template](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md?plain=1).

## Model Details

### Model Description

AlignDRAW was the first model that could generate images from text. The original model is no longer available online, so we decided to reimplement it in Python (PyTorch)!

We trained it on handwritten digits paired with text prompts.

- **Developed by:** Bertug Gunel
- **Funded by:** N/A
- **Shared by:** N/A
- **Model type:** Attention + VAE + RNN
- **Language(s) (NLP):** English
- **License:** cc-by-nc-sa-4.0
- **Finetuned from model:** None

### Model Sources

- **Repository:** Coming soon!
- **Paper:** Coming soon!
- **Demo:** Coming soon!

## Uses

### Direct Use

You can download the model weights and the caption-embedding head, then run them with the code below. Change the two file paths to wherever you saved the files. A packaged direct-use interface is coming soon!

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from safetensors.torch import load_file

# --- Configuration ---
IMG_SIZE = 28
INPUT_DIM = IMG_SIZE * IMG_SIZE   # flattened 28x28 image
LATENT_DIM = 100
TIMESTEPS = 10
CAPTION_EMBED_DIM = 50
HIDDEN_DIM = 256                  # LSTM hidden size of the released weights
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Model definitions (attribute names must match the saved state dicts) ---
class CaptionEmbed(nn.Module):
    """Maps a digit class ID (0-9) to a learned caption vector."""

    def __init__(self, num_classes=10, embed_dim=CAPTION_EMBED_DIM):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)

    def forward(self, labels):
        return self.embed(labels)


class DRAWTextModel(nn.Module):
    """Caption-conditioned recurrent VAE with an LSTM encoder and decoder."""

    def __init__(self, input_dim, latent_dim, timesteps, caption_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim + caption_dim, HIDDEN_DIM)
        self.decoder = nn.LSTM(latent_dim + caption_dim, HIDDEN_DIM)
        self.fc_mu = nn.Linear(HIDDEN_DIM, latent_dim)
        self.fc_logvar = nn.Linear(HIDDEN_DIM, latent_dim)
        self.fc_dec = nn.Linear(HIDDEN_DIM, input_dim)

    def forward(self, x_seq, cap_seq):
        # x_seq: (T, B, input_dim), cap_seq: (T, B, caption_dim)
        batch = x_seq.size(1)
        canvas = torch.zeros_like(x_seq)
        h_enc = (torch.zeros(1, batch, HIDDEN_DIM, device=x_seq.device),
                 torch.zeros(1, batch, HIDDEN_DIM, device=x_seq.device))
        h_dec = (torch.zeros(1, batch, HIDDEN_DIM, device=x_seq.device),
                 torch.zeros(1, batch, HIDDEN_DIM, device=x_seq.device))
        mus, logvars = [], []
        for t in range(x_seq.size(0)):
            # canvas[t] is still zero here, so diff equals x_seq[t] - 0.5.
            diff = x_seq[t] - torch.sigmoid(canvas[t])
            diff_cap = torch.cat([diff, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_enc = self.encoder(diff_cap, h_enc)
            enc_h = h_enc[0].squeeze(0)
            mu, logvar = self.fc_mu(enc_h), self.fc_logvar(enc_h)
            # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
            z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
            z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_dec = self.decoder(z_cap, h_dec)
            # Each step writes its own canvas slot (no accumulation across steps).
            canvas[t] = canvas[t] + self.fc_dec(h_dec[0].squeeze(0))
            mus.append(mu)
            logvars.append(logvar)
        # mu/logvar are returned so a training loop can compute the KL term.
        return canvas, torch.stack(mus), torch.stack(logvars)


# --- Load pretrained weights ---
caption_model = CaptionEmbed().to(DEVICE)
model = DRAWTextModel(INPUT_DIM, LATENT_DIM, TIMESTEPS, CAPTION_EMBED_DIM).to(DEVICE)

caption_model.load_state_dict(load_file("caption_embed.safetensors"))  # path to the embedding head on your PC
model.load_state_dict(load_file("draw_model.safetensors"))             # path to the model weights on your PC
caption_model.eval()
model.eval()

# --- Prompt mapping (prompt text -> digit class ID) ---
prompt2digit = {
    "number zero": 0, "number one": 1, "number two": 2, "number three": 3,
    "number four": 4, "number five": 5, "number six": 6, "number seven": 7,
    "number eight": 8, "number nine": 9,
}

# --- Interactive generation loop ---
while True:
    prompt = input("Enter a prompt (e.g. 'number two'; type 'q' to quit): ").strip().lower()
    if prompt == "q":
        print("Exiting.")
        break
    if prompt not in prompt2digit:
        print(f"Unknown prompt: {prompt}")
        continue

    labels = torch.tensor([prompt2digit[prompt]], device=DEVICE)
    with torch.no_grad():  # inference only; also lets matplotlib convert the result
        caption_vec = caption_model(labels)                          # (1, CAPTION_EMBED_DIM)
        cap_seq = caption_vec.unsqueeze(0).repeat(TIMESTEPS, 1, 1)   # (T, 1, CAPTION_EMBED_DIM)

        # Generation: skip the encoder and sample z from the prior N(0, I) at every step.
        h_dec = (torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                 torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE))
        canvas = torch.zeros(TIMESTEPS, 1, INPUT_DIM, device=DEVICE)
        for t in range(TIMESTEPS):
            z = torch.randn(1, LATENT_DIM, device=DEVICE)
            z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_dec = model.decoder(z_cap, h_dec)
            canvas[t] = canvas[t] + model.fc_dec(h_dec[0].squeeze(0))

        # The last canvas slot is the generated image.
        img = torch.sigmoid(canvas[-1]).view(1, 1, IMG_SIZE, IMG_SIZE)

    grid = make_grid(img.cpu(), normalize=True)
    plt.figure(figsize=(3, 3))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()
```

### Downstream Use

Fine-tuning is not officially supported (though you are welcome to try).

### Out-of-Scope Use

This model was trained only on the MNIST dataset (handwritten digits), so it can only generate single digits. Full text-to-image generation is coming soon!

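Embedding IDs:

| Embedding ID | Generated digit | Prompt in the example code |
|---|---|---|
| 0 | 0 | `number zero` |
| 1 | 1 | `number one` |
| 2 | 2 | `number two` |
| 3 | 3 | `number three` |
| 4 | 4 | `number four` |
| 5 | 5 | `number five` |
| 6 | 6 | `number six` |
| 7 | 7 | `number seven` |
| 8 | 8 | `number eight` |
| 9 | 9 | `number nine` |

No tokenization is needed: the embedding ID is the digit class.
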
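The ID mapping above can be made a little more forgiving about how prompts are written. The sketch below is a hypothetical helper (`prompt_to_id` and `DIGIT_WORDS` are not part of the released code) that accepts prompts such as `number three`, `three`, or `3` and returns the embedding ID; it assumes only the ten MNIST digit classes exist.

```python
# Hypothetical helper, not part of the released code.
DIGIT_WORDS = ["zero", "one", "two", "three", "four",
               "five", "six", "seven", "eight", "nine"]


def prompt_to_id(prompt: str) -> int:
    """Return the embedding ID (= digit class) for a prompt.

    Accepts forms like 'number three', 'three', 'Number 3' or '3'.
    Raises ValueError for anything outside the ten MNIST classes.
    """
    text = prompt.strip().lower()
    if text.startswith("number "):
        text = text[len("number "):].strip()
    if text in DIGIT_WORDS:
        return DIGIT_WORDS.index(text)
    if text in {str(d) for d in range(10)}:
        return int(text)
    raise ValueError(f"Unknown prompt: {prompt!r}")


print(prompt_to_id("Number three"))  # 3
print(prompt_to_id("7"))             # 7
```

The returned integer can be wrapped as `torch.tensor([...])` and passed to the caption embedding in the same way as `labels` in the Direct Use code.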
## Bias, Risks, and Limitations

No known risks. The main limitation is that the model can only generate MNIST-style handwritten digits.

## How to Get Started with the Model

Use the code in the Direct Use section above to get started with the model.

## Training Details

### Training Data

The model was trained for 1 epoch on the MNIST dataset.

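This card does not document the training objective. For a VAE-style model like `DRAWTextModel`, a common choice is the negative evidence lower bound: a reconstruction term plus a KL term against the N(0, I) prior, summed over time steps. The sketch below assumes that objective, a Bernoulli (binary cross-entropy) reconstruction of the final canvas slot, and the per-step `mu`/`logvar` values computed inside `DRAWTextModel.forward`, stacked into tensors of shape (T, B, Z). `draw_vae_loss` is a hypothetical name; this is not the author's training code.

```python
import torch
import torch.nn.functional as F


def draw_vae_loss(canvas_logits, target, mus, logvars):
    """Hypothetical ELBO-style loss for a caption-conditioned recurrent VAE.

    canvas_logits: (T, B, D) pre-sigmoid canvas; only the final step is scored.
    target:        (B, D) pixel intensities in [0, 1].
    mus, logvars:  (T, B, Z) posterior parameters from every step.
    """
    # Reconstruction term: Bernoulli negative log-likelihood of the final canvas.
    recon = F.binary_cross_entropy_with_logits(
        canvas_logits[-1], target, reduction="sum")
    # KL(q(z_t | x, y) || N(0, I)), summed over steps, batch and latent dims.
    kl = -0.5 * torch.sum(1 + logvars - mus.pow(2) - logvars.exp())
    return (recon + kl) / target.size(0)  # average over the batch


# Shape check with random tensors (10 steps, batch of 4, 784 pixels, 100 latents).
T, B, D, Z = 10, 4, 784, 100
loss = draw_vae_loss(torch.randn(T, B, D), torch.rand(B, D),
                     torch.randn(T, B, Z), torch.randn(T, B, Z))
print(loss.item())
```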
#### Preprocessing

Not documented.

## Evaluation

No evaluation has been performed yet.

### Testing Data, Factors & Metrics

#### Testing Data

None (the model has not been tested).

#### Metrics

- **Accuracy:** training accuracy.

### Results

Examples:

Input ID 0 (prompt: "number zero")

 |

Input ID 3 (prompt: "number three")

 |

#### Summary

The model can generate good-quality digits for all ten classes (0 to 9), based on visual inspection; no quantitative evaluation has been run. The full version is coming soon!

## Environmental Impact

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** NVIDIA A100 (40 GB)
- **Hours used:** < 0.1
- **Cloud Provider:** Google Colab
- **Compute Region:** Not reported
- **Carbon Emitted:** Not reported

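The calculator referenced above essentially multiplies energy use by the carbon intensity of the local grid. A minimal sketch of that arithmetic follows. The average power draw, the grid intensity, and the optional PUE (datacenter overhead) factor are all caller-supplied assumptions, since this card reports neither the compute region nor measured power. `estimate_co2_grams` is a hypothetical helper, not part of the calculator.

```python
def estimate_co2_grams(avg_power_watts: float, hours: float,
                       grid_g_per_kwh: float, pue: float = 1.0) -> float:
    """Rough CO2-equivalent estimate in grams.

    energy (kWh)  = power (W) x hours x PUE / 1000
    emissions (g) = energy (kWh) x grid carbon intensity (gCO2eq/kWh)
    All inputs are assumptions supplied by the caller.
    """
    energy_kwh = avg_power_watts * hours * pue / 1000.0
    return energy_kwh * grid_g_per_kwh


# Illustrative inputs only (not measured for this model).
print(round(estimate_co2_grams(400, 0.1, 400), 2))  # 16.0
```

For example, assuming a 400 W average draw over the reported upper bound of 0.1 h on a hypothetical 400 gCO2eq/kWh grid gives 16 g. This is an illustration, not a measured value for this model.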
## Technical Specifications

### Model Architecture and Objective

RNN (LSTM) + VAE + Attention

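The following transcribes one step t = 1, ..., T (T = 10) of the recurrence in the `DRAWTextModel` code from the Direct Use section, with both LSTM states starting at zero. Here, x_t is the flattened input image at step t, y is the caption embedding, [·, ·] denotes concatenation, and sigmoid is the logistic function. Each canvas slot is still zero when the encoder reads it, so the error image reduces to x_t minus 1/2. At sampling time, the interactive loop skips the encoder lines and draws z_t directly from N(0, I). The generated image is sigmoid(c_T).

$$
\begin{aligned}
\hat{x}_t &= x_t - \operatorname{sigmoid}(\mathbf{0}) = x_t - \tfrac{1}{2} \\
h^{\text{enc}}_t &= \mathrm{LSTM}^{\text{enc}}\big([\hat{x}_t, y],\ h^{\text{enc}}_{t-1}\big) \\
\mu_t &= W_{\mu} h^{\text{enc}}_t + b_{\mu}, \qquad \log \sigma_t^2 = W_{\ell} h^{\text{enc}}_t + b_{\ell} \\
z_t &= \mu_t + \exp\!\big(\tfrac{1}{2}\log \sigma_t^2\big) \odot \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, I) \\
h^{\text{dec}}_t &= \mathrm{LSTM}^{\text{dec}}\big([z_t, y],\ h^{\text{dec}}_{t-1}\big) \\
c_t &= W_{\text{dec}} h^{\text{dec}}_t + b_{\text{dec}}
\end{aligned}
$$
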
## Model Card Authors

Bertug Gunel

## Model Card Contact

bertugscpmail@gmail.com