Update README.md
#3
by
hypothetical
- opened
README.md
CHANGED
|
@@ -33,54 +33,32 @@ __Goals of elastic models:__
|
|
| 33 |
|
| 34 |
## Inference
|
| 35 |
|
| 36 |
-
To infer our models, you just need to replace the `transformers` import with `elastic_models.transformers`:
|
| 37 |
|
| 38 |
```python
|
| 39 |
import torch
|
| 40 |
-
from
|
| 41 |
-
from elastic_models.transformers import AutoModelForCausalLM
|
| 42 |
|
| 43 |
# Currently we require you to have your HF token
|
| 44 |
# as we use original weights for part of layers and
|
| 45 |
# model configuration as well
|
| 46 |
-
|
| 47 |
hf_token = ''
|
| 48 |
-
hf_cache_dir = ''
|
| 49 |
device = torch.device("cuda")
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
model_name, token=hf_token
|
| 54 |
-
)
|
| 55 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 56 |
-
model_name,
|
| 57 |
-
token=hf_token,
|
| 58 |
-
cache_dir=hf_cache_dir,
|
| 59 |
torch_dtype=torch.bfloat16,
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
with torch.inference_mode():
|
| 71 |
-
generate_ids = model.generate(**inputs, max_length=500)
|
| 72 |
-
|
| 73 |
-
input_len = inputs['input_ids'].shape[1]
|
| 74 |
-
generate_ids = generate_ids[:, input_len:]
|
| 75 |
-
output = tokenizer.batch_decode(
|
| 76 |
-
generate_ids,
|
| 77 |
-
skip_special_tokens=True,
|
| 78 |
-
clean_up_tokenization_spaces=False
|
| 79 |
-
)[0]
|
| 80 |
-
|
| 81 |
-
# Validate answer
|
| 82 |
-
print(f"# Q:\n{prompt}\n")
|
| 83 |
-
print(f"# A:\n{output}\n")
|
| 84 |
```
|
| 85 |
|
| 86 |
### Installation
|
|
|
|
| 33 |
|
| 34 |
## Inference
|
| 35 |
|
| 36 |
+
To infer our models, you just need to replace `transformers` import with `elastic_models.diffusers`:
|
| 37 |
|
| 38 |
```python
|
| 39 |
import torch
|
| 40 |
+
from elastic_models.diffusers import FluxPipeline
|
|
|
|
| 41 |
|
| 42 |
# Currently we require you to have your HF token
|
| 43 |
# as we use original weights for part of layers and
|
| 44 |
# model configuration as well
|
| 45 |
+
model_name = 'black-forest-labs/FLUX.1-dev'
|
| 46 |
hf_token = ''
|
|
|
|
| 47 |
device = torch.device("cuda")
|
| 48 |
|
| 49 |
+
pipeline = FluxPipeline.from_pretrained(
|
| 50 |
+
model_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
torch_dtype=torch.bfloat16,
|
| 52 |
+
mode='S'
|
| 53 |
+
)
|
| 54 |
+
pipeline.to(device)
|
| 55 |
+
|
| 56 |
+
prompts = ["Kitten eating a banana"]
|
| 57 |
+
|
| 58 |
+
output = pipeline(prompt=prompts)
|
| 59 |
+
|
| 60 |
+
for prompt, output_image in zip(prompts, output.images):
|
| 61 |
+
output_image.save((prompt.replace(' ', '_') + '.png'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
```
|
| 63 |
|
| 64 |
### Installation
|