---
license: cc-by-nc-sa-4.0
language:
- en
pipeline_tag: text-to-image
---
# Model Card for AlignDRAW-open
<!-- Provide a quick summary of what the model is/does. -->
AlignDRAW-open is a Python reimplementation of AlignDRAW, a text-to-image model. This version is trained on MNIST and generates handwritten digits from short text prompts.
## Model Details
### Model Description
<!-- Provide a longer summary of what this model is. -->
AlignDRAW was one of the first models to generate images from text. The original model is no longer easy to find online, so we reimplemented it in Python.
This version is trained on handwritten digits paired with text prompts.
- **Developed by:** Bertug Gunel
- **Funded by [optional]:** None
- **Shared by [optional]:** None
- **Model type:** Attention + VAE + RNN
- **Language(s) (NLP):** English
- **License:** cc-by-nc-sa-4.0
- **Finetuned from model [optional]:** None
### Model Sources [optional]
<!-- Provide the basic links for the model. -->
- **Repository:** Coming soon
- **Paper [optional]:** Coming soon
- **Demo [optional]:** Coming soon
## Uses
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
### Direct Use
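Download the model weights (`draw_model.safetensors`) and the caption embedding head (`caption_embed.safetensors`), then run the script below. A packaged direct-use interface is coming soon.

Code: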
```
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from safetensors.torch import load_file
from torchvision.utils import make_grid

# --- Configuration ---
IMG_SIZE = 28
INPUT_DIM = IMG_SIZE * IMG_SIZE
LATENT_DIM = 100
TIMESTEPS = 10
CAPTION_EMBED_DIM = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Model definitions ---
class CaptionEmbed(nn.Module):
    """Maps a digit class ID (0-9) to a caption embedding vector."""

    def __init__(self, num_classes=10, embed_dim=CAPTION_EMBED_DIM):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)

    def forward(self, labels):
        return self.embed(labels)


class DRAWTextModel(nn.Module):
    def __init__(self, input_dim, latent_dim, timesteps, caption_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim + caption_dim, 256)
        self.decoder = nn.LSTM(latent_dim + caption_dim, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.fc_dec = nn.Linear(256, input_dim)

    def forward(self, x_seq, cap_seq):
        # x_seq:   (timesteps, batch, input_dim)
        # cap_seq: (timesteps, batch, caption_dim)
        batch = x_seq.size(1)
        canvas = torch.zeros_like(x_seq)
        h_enc = (torch.zeros(1, batch, 256, device=x_seq.device),
                 torch.zeros(1, batch, 256, device=x_seq.device))
        h_dec = (torch.zeros(1, batch, 256, device=x_seq.device),
                 torch.zeros(1, batch, 256, device=x_seq.device))
        for t in range(x_seq.size(0)):
            # canvas[t] is still zero here, so each step stores only its own
            # decoder output; the canvas is not accumulated across steps.
            diff = x_seq[t] - torch.sigmoid(canvas[t])
            diff_cap = torch.cat([diff, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_enc = self.encoder(diff_cap, h_enc)
            enc_h = h_enc[0].squeeze(0)
            mu = self.fc_mu(enc_h)
            logvar = self.fc_logvar(enc_h)
            # Reparameterization trick: z = mu + std * eps
            std = torch.exp(0.5 * logvar)
            z = mu + std * torch.randn_like(std)
            z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_dec = self.decoder(z_cap, h_dec)
            dec_h = h_dec[0].squeeze(0)
            canvas[t] = canvas[t] + self.fc_dec(dec_h)
        return canvas


# --- Load pretrained models ---
caption_model = CaptionEmbed().to(DEVICE)
model = DRAWTextModel(INPUT_DIM, LATENT_DIM, TIMESTEPS, CAPTION_EMBED_DIM).to(DEVICE)
caption_state = load_file("caption_embed.safetensors")  # path to the embedding head on your machine
model_state = load_file("draw_model.safetensors")  # path to the model weights on your machine
caption_model.load_state_dict(caption_state)
model.load_state_dict(model_state)
caption_model.eval()
model.eval()

# --- Prompt mapping ---
prompt2digit = {
    "number zero": 0,
    "number one": 1,
    "number two": 2,
    "number three": 3,
    "number four": 4,
    "number five": 5,
    "number six": 6,
    "number seven": 7,
    "number eight": 8,
    "number nine": 9,
}

# --- Interactive generation loop ---
while True:
    prompt = input("Enter a prompt (e.g. 'number two', or 'q' to quit): ").strip().lower()
    if prompt == "q":
        print("Exiting.")
        break
    if prompt not in prompt2digit:
        print(f"Unknown prompt: {prompt}")
        continue

    digit = prompt2digit[prompt]
    labels = torch.tensor([digit], device=DEVICE)

    # Generation uses only the decoder, fed with latents sampled from the prior.
    with torch.no_grad():
        caption_vec = caption_model(labels)
        cap_seq = caption_vec.unsqueeze(0).repeat(TIMESTEPS, 1, 1)
        h_dec = (torch.zeros(1, 1, 256, device=DEVICE),
                 torch.zeros(1, 1, 256, device=DEVICE))
        canvas = torch.zeros(TIMESTEPS, 1, INPUT_DIM, device=DEVICE)
        for t in range(TIMESTEPS):
            z = torch.randn(1, LATENT_DIM, device=DEVICE)
            z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_dec = model.decoder(z_cap, h_dec)
            dec_h = h_dec[0].squeeze(0)
            canvas[t] = canvas[t] + model.fc_dec(dec_h)
        # The final image is the last step's output squashed into [0, 1].
        img = torch.sigmoid(canvas[-1]).view(1, 1, IMG_SIZE, IMG_SIZE)

    grid = make_grid(img.cpu(), normalize=True)
    plt.figure(figsize=(3, 3))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()
```
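
For scripted (non-interactive) use, the sampling loop above can be wrapped in a function. The following is a minimal sketch, not part of the released code: `sample_digits` is a hypothetical helper name, and the modules it is called with here are randomly initialized stand-ins so the snippet runs without the weight files. With the real model you would pass `model.decoder`, `model.fc_dec`, and the output of `caption_model`.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100
CAPTION_EMBED_DIM = 50
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256


@torch.no_grad()
def sample_digits(decoder, fc_dec, caption_vec, timesteps=10):
    """Run the decoder-only sampling loop (hypothetical helper).

    caption_vec: (batch, CAPTION_EMBED_DIM) caption embeddings.
    Returns images of shape (batch, 1, 28, 28) with values in [0, 1].
    """
    batch = caption_vec.size(0)
    device = caption_vec.device
    state = (torch.zeros(1, batch, HIDDEN_DIM, device=device),
             torch.zeros(1, batch, HIDDEN_DIM, device=device))
    logits = None
    for _ in range(timesteps):
        # Sample a latent from the prior and condition it on the caption.
        z = torch.randn(batch, LATENT_DIM, device=device)
        step_in = torch.cat([z, caption_vec], dim=-1).unsqueeze(0)
        _, state = decoder(step_in, state)
        # Only the last step's output is kept, as in the script above.
        logits = fc_dec(state[0].squeeze(0))
    return torch.sigmoid(logits).view(batch, 1, 28, 28)


# Stand-in modules with random weights (assumption: same shapes as the card's model).
decoder = nn.LSTM(LATENT_DIM + CAPTION_EMBED_DIM, HIDDEN_DIM)
fc_dec = nn.Linear(HIDDEN_DIM, INPUT_DIM)
embed = nn.Embedding(10, CAPTION_EMBED_DIM)

images = sample_digits(decoder, fc_dec, embed(torch.tensor([3, 7])))
print(images.shape)  # torch.Size([2, 1, 28, 28])
```

Because the stand-in weights are random, the output here is noise; the sketch only demonstrates the tensor shapes and control flow.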
### Downstream Use [optional]
Fine-tuning is not currently supported, although nothing prevents you from trying.
### Out-of-Scope Use
This model was trained only on the MNIST dataset (handwritten digits), so it can only generate digits. Full text-to-image generation is coming soon.

Embedding IDs:
- 0 = "0"
- 1 = "1"
- 2 = "2"
- 3 = "3"
- 4 = "4"
- 5 = "5"
- 6 = "6"
- 7 = "7"
- 8 = "8"
- 9 = "9"

No tokenization is needed: the embedding ID is the digit class.
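
Since the embedding ID is just the digit class, turning a prompt into an ID is a plain lookup. A minimal sketch, assuming prompts of the form "number <word>" as in the script above; `prompt_to_id` and `DIGIT_WORDS` are hypothetical names introduced for illustration:

```python
DIGIT_WORDS = ["zero", "one", "two", "three", "four",
               "five", "six", "seven", "eight", "nine"]


def prompt_to_id(prompt):
    """Return the embedding ID for a prompt like 'Number three', or None."""
    words = prompt.strip().lower().split()
    if len(words) == 2 and words[0] == "number" and words[1] in DIGIT_WORDS:
        # The position in DIGIT_WORDS is the digit class, which is the ID.
        return DIGIT_WORDS.index(words[1])
    return None


print(prompt_to_id("Number three"))  # 3
print(prompt_to_id("a cat"))         # None
```

Lowercasing and splitting makes the lookup tolerant of capitalization and extra whitespace; any prompt outside the ten supported ones returns `None`.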
## Bias, Risks, and Limitations
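No known risks have been identified. The main limitation is scope: the model generates only 28×28 handwritten digits.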
## How to Get Started with the Model
See the script in the Direct Use section above to get started with the model.
## Training Details
### Training Data
The model was trained for 1 epoch on the MNIST dataset.
#### Preprocessing [optional]
Not documented.
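
The model's `forward(x_seq, cap_seq)` expects inputs shaped `(timesteps, batch, features)`. The sketch below shows one way a batch could be arranged into that layout. It is an assumption, not the documented training pipeline: `make_sequences` is a hypothetical helper, repeating the same image at every timestep is a guess based on how `cap_seq` is built in the generation script, and a random tensor stands in for a real MNIST batch.

```python
import torch
import torch.nn as nn

TIMESTEPS = 10
CAPTION_EMBED_DIM = 50


def make_sequences(images, labels, embed, timesteps=TIMESTEPS):
    """Arrange a batch into (timesteps, batch, features) tensors.

    images: (batch, 1, 28, 28) floats in [0, 1]
    labels: (batch,) integer digit classes
    """
    flat = images.reshape(images.size(0), -1)            # (batch, 784)
    x_seq = flat.unsqueeze(0).repeat(timesteps, 1, 1)    # (T, batch, 784)
    cap_seq = embed(labels).unsqueeze(0).repeat(timesteps, 1, 1)  # (T, batch, 50)
    return x_seq, cap_seq


embed = nn.Embedding(10, CAPTION_EMBED_DIM)
images = torch.rand(4, 1, 28, 28)    # random stand-in for an MNIST batch
labels = torch.tensor([0, 3, 5, 9])
x_seq, cap_seq = make_sequences(images, labels, embed)
print(x_seq.shape, cap_seq.shape)
```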
## Evaluation
No formal evaluation has been performed.
### Testing Data, Factors & Metrics
#### Testing Data
No separate test data was used.
#### Metrics
- Accuracy: training accuracy only.
### Results
Examples:
Input ID 0 (Prompt = "Number zero")
![indir.png](https://cdn-uploads.huggingface.co/production/uploads/66eeed95df079df1a35a553b/V2kgcXxzU6XgNnGQ2VO9d.png)
Input ID 3 (Prompt = "Number three")
![indir (1).png](https://cdn-uploads.huggingface.co/production/uploads/66eeed95df079df1a35a553b/nMeRK3kvB5vilVZuFExRz.png)
#### Summary
The model can generate good-quality handwritten digits (0 through 9).
A full version is coming soon.
## Environmental Impact
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
- **Hardware Type:** A100 40GB
- **Hours used:** <0.1
- **Cloud Provider:** Google Colab
- **Compute Region:** Not reported
- **Carbon Emitted:** Not reported
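
For a rough order-of-magnitude figure, emissions can be approximated as energy used times grid carbon intensity. The sketch below is illustrative only and is not a measurement for this model: `estimate_co2_grams` is a hypothetical helper, and the power draw, PUE, and carbon intensity passed to it are assumed placeholder values, since the compute region is not reported.

```python
def estimate_co2_grams(hours, power_watts, pue, grams_co2_per_kwh):
    """Estimate emissions in grams of CO2eq.

    energy (kWh) = power (kW) * hours * PUE
    emissions    = energy * grid carbon intensity (gCO2eq/kWh)
    """
    energy_kwh = (power_watts / 1000.0) * hours * pue
    return energy_kwh * grams_co2_per_kwh


# Assumed inputs: 0.1 h upper bound from the card; 400 W, PUE 1.0, and
# 400 gCO2eq/kWh are placeholders, not reported values.
estimate = estimate_co2_grams(hours=0.1, power_watts=400,
                              pue=1.0, grams_co2_per_kwh=400)
print(f"{estimate:.1f} g CO2eq")
```

Swap in the actual region's carbon intensity and measured power draw if they become available.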
## Technical Specifications [optional]
### Model Architecture and Objective
RNN + VAE + Attention
## Model Card Authors [optional]
Bertug Gunel
## Model Card Contact
bertugscpmail@gmail.com