Raziel1234 commited on
Commit
5331054
ยท
verified ยท
1 Parent(s): 8c7686d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +24 -25
README.md CHANGED
@@ -52,53 +52,52 @@ While the current version (v1) produces stylistic, slightly "painterly" or "pixe
52
  ## ๐Ÿ› ๏ธ How to use
53
  ```python
54
  import torch
55
- from transformers import AutoTokenizer, CLIPTextModel, AutoConfig
56
- from huggingface_hub import hf_hub_download
57
- from safetensors.torch import load_file
58
  import matplotlib.pyplot as plt
59
  import numpy as np
60
-
61
- # import classes
 
62
 
63
  def generate_fixed_from_hub(prompt, model_id="TopAI-1/Pixel-1"):
64
  device = "cuda" if torch.cuda.is_available() else "cpu"
65
  print(f"๐Ÿš€ Working on {device}...")
66
 
67
- # 1. Download
68
- print("๐Ÿ“ฅ Downloading weights directly from Hub...")
69
- weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors")
70
- config_path = hf_hub_download(repo_id=model_id, filename="config.json")
 
71
 
72
- # 2. loading CLIP
73
  tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
74
  text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
75
 
76
- # 3. load wieghts
77
- config = TopAIImageConfig.from_pretrained(model_id)
78
- model = TopAIImageGenerator(config)
79
-
80
- state_dict = load_file(weights_path, device=device)
 
 
 
81
 
82
- clean_sd = {k.replace('\xa0', ' '): v for k, v in state_dict.items()}
83
-
84
- # Loading
85
- model.load_state_dict(clean_sd, strict=False)
86
- model.to(device).eval()
87
- print("โœ… Weights loaded perfectly!")
88
 
89
- # 4. Generation
 
90
  inputs = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
 
91
  with torch.no_grad():
92
  emb = text_encoder(inputs.input_ids).pooler_output
93
  out = model(emb)
94
 
95
- # 5. Show
96
  img = (out.squeeze(0).cpu().permute(1, 2, 0).numpy() + 1.0) / 2.0
97
  plt.figure(figsize=(8, 8))
98
  plt.imshow(np.clip(img, 0, 1))
99
  plt.axis('off')
100
- plt.title(prompt)
101
  plt.show()
102
 
103
- # Run
104
  generate_fixed_from_hub("Window with metal bars and fence shadow")
 
52
  ## ๐Ÿ› ๏ธ How to use
53
  ```python
54
  import torch
 
 
 
55
  import matplotlib.pyplot as plt
56
  import numpy as np
57
+ import os
58
+ import shutil
59
+ from transformers import AutoTokenizer, CLIPTextModel, AutoModel, AutoConfig
60
 
61
  def generate_fixed_from_hub(prompt, model_id="TopAI-1/Pixel-1"):
62
  device = "cuda" if torch.cuda.is_available() else "cpu"
63
  print(f"๐Ÿš€ Working on {device}...")
64
 
65
+ # 1. ื ื™ืงื•ื™ Cache ื›ื“ื™ ืœื•ื•ื“ื ืฉืืชื” ืžื•ืฉืš ืืช ื”ืชื™ืงื•ื ื™ื ื”ื—ื“ืฉื™ื ืžื”-Hub
66
+ cache_path = os.path.expanduser(f"~/.cache/huggingface/hub/models--{model_id.replace('/', '--')}")
67
+ if os.path.exists(cache_path):
68
+ print("๐Ÿงน Clearing old cache to fetch your latest fixes...")
69
+ shutil.rmtree(cache_path)
70
 
71
+ # 2. ื˜ืขื™ื ืช CLIP
72
  tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
73
  text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
74
 
75
+ # 3. ื˜ืขื™ื ืช ื”ืžื•ื“ืœ ื•ื”ืงื•ื ืคื™ื’ ืื•ื˜ื•ืžื˜ื™ืช ืžื”-Hub
76
+ # ื‘ื–ื›ื•ืช ื”-auto_map ื‘-config.json, transformers ื™ืžืฆื ืœื‘ื“ ืืช ื”ืžื—ืœืงื•ืช
77
+ print("๐Ÿ“ฅ Downloading architecture and weights directly from Hub...")
78
+ model = AutoModel.from_pretrained(
79
+ model_id,
80
+ trust_remote_code=True,
81
+ force_download=True
82
+ ).to(device)
83
 
84
+ model.eval()
85
+ print("โœ… Model loaded successfully!")
 
 
 
 
86
 
87
+ # 4. ื™ืฆื™ืจื”
88
+ print(f"๐ŸŽจ Generating: {prompt}")
89
  inputs = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").to(device)
90
+
91
  with torch.no_grad():
92
  emb = text_encoder(inputs.input_ids).pooler_output
93
  out = model(emb)
94
 
95
+ # 5. ืชืฆื•ื’ื”
96
  img = (out.squeeze(0).cpu().permute(1, 2, 0).numpy() + 1.0) / 2.0
97
  plt.figure(figsize=(8, 8))
98
  plt.imshow(np.clip(img, 0, 1))
99
  plt.axis('off')
 
100
  plt.show()
101
 
102
+ # ื”ืจืฆื”
103
  generate_fixed_from_hub("Window with metal bars and fence shadow")