Update README.md
Browse files
README.md
CHANGED
|
@@ -51,13 +51,54 @@ While the current version (v1) produces stylistic, slightly "painterly" or "pixe
|
|
| 51 |
|
| 52 |
## 🛠️ How to use
|
| 53 |
```python
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
from safetensors.torch import load_file
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
## 🛠️ How to use
|
| 53 |
```python
|
| 54 |
+
import torch
|
| 55 |
+
from transformers import AutoTokenizer, CLIPTextModel, AutoConfig
|
| 56 |
+
from huggingface_hub import hf_hub_download
|
| 57 |
from safetensors.torch import load_file
|
| 58 |
+
import matplotlib.pyplot as plt
|
| 59 |
+
import numpy as np
|
| 60 |
+
|
| 61 |
+
# import classes
|
| 62 |
+
|
| 63 |
+
def generate_fixed_from_hub(prompt, model_id="TopAI-1/Pixel-1"):
|
| 64 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 65 |
+
print(f"🚀 Working on {device}...")
|
| 66 |
+
|
| 67 |
+
# 1. Download
|
| 68 |
+
print("📥 Downloading weights directly from Hub...")
|
| 69 |
+
weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors")
|
| 70 |
+
config_path = hf_hub_download(repo_id=model_id, filename="config.json")
|
| 71 |
+
|
| 72 |
+
# 2. loading CLIP
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 74 |
+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
| 75 |
+
|
| 76 |
+
# 3. load wieghts
|
| 77 |
+
config = TopAIImageConfig.from_pretrained(model_id)
|
| 78 |
+
model = TopAIImageGenerator(config)
|
| 79 |
+
|
| 80 |
+
state_dict = load_file(weights_path, device=device)
|
| 81 |
+
|
| 82 |
+
clean_sd = {k.replace('\xa0', ' '): v for k, v in state_dict.items()}
|
| 83 |
+
|
| 84 |
+
# Loading
|
| 85 |
+
model.load_state_dict(clean_sd, strict=False)
|
| 86 |
+
model.to(device).eval()
|
| 87 |
+
print("✅ Weights loaded perfectly!")
|
| 88 |
+
|
| 89 |
+
# 4. Generation
|
| 90 |
+
inputs = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
emb = text_encoder(inputs.input_ids).pooler_output
|
| 93 |
+
out = model(emb)
|
| 94 |
+
|
| 95 |
+
# 5. Show
|
| 96 |
+
img = (out.squeeze(0).cpu().permute(1, 2, 0).numpy() + 1.0) / 2.0
|
| 97 |
+
plt.figure(figsize=(8, 8))
|
| 98 |
+
plt.imshow(np.clip(img, 0, 1))
|
| 99 |
+
plt.axis('off')
|
| 100 |
+
plt.title(prompt)
|
| 101 |
+
plt.show()
|
| 102 |
+
|
| 103 |
+
# Run
|
| 104 |
+
generate_fixed_from_hub("Window with metal bars and fence shadow")
|