Files changed (1)
README.md +15 -37
README.md CHANGED
@@ -33,54 +33,32 @@ __Goals of elastic models:__
  
  ## Inference
  
- To infer our models, you just need to replace the `transformers` import with `elastic_models.transformers`:
+ To infer our models, you just need to replace the `diffusers` import with `elastic_models.diffusers`:
  
  ```python
  import torch
- from transformers import AutoTokenizer
- from elastic_models.transformers import AutoModelForCausalLM
+ from elastic_models.diffusers import FluxPipeline
  
  # Currently we require your HF token,
  # as we use the original weights for part of the layers
  # and the model configuration as well
- model_name = "mistralai/Mistral-7B-Instruct-v0.3"
+ model_name = 'black-forest-labs/FLUX.1-dev'
  hf_token = ''
- hf_cache_dir = ''
  device = torch.device("cuda")
  
- # Create model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name, token=hf_token
- )
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     token=hf_token,
-     cache_dir=hf_cache_dir,
+ pipeline = FluxPipeline.from_pretrained(
+     model_name,
      torch_dtype=torch.bfloat16,
-     attn_implementation="sdpa",
-     mode='s'
- ).to(device)
- model.generation_config.pad_token_id = tokenizer.eos_token_id
-
- # Inference is as simple as with the transformers library
- prompt = "Describe basics of DNNs quantization."
- inputs = tokenizer(prompt, return_tensors="pt")
- inputs.to(device)
-
- with torch.inference_mode():
-     generate_ids = model.generate(**inputs, max_length=500)
-
- input_len = inputs['input_ids'].shape[1]
- generate_ids = generate_ids[:, input_len:]
- output = tokenizer.batch_decode(
-     generate_ids,
-     skip_special_tokens=True,
-     clean_up_tokenization_spaces=False
- )[0]
-
- # Validate the answer
- print(f"# Q:\n{prompt}\n")
- print(f"# A:\n{output}\n")
+     mode='S'
+ )
+ pipeline.to(device)
+
+ prompts = ["Kitten eating a banana"]
+
+ output = pipeline(prompt=prompts)
+
+ for prompt, output_image in zip(prompts, output.images):
+     output_image.save(prompt.replace(' ', '_') + '.png')
  ```
  
  ### Installation
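
For context on the change above: the added snippet mirrors stock `diffusers` usage, differing only in the import path and the extra `mode` argument. Below is a minimal sketch of the equivalent stock call, assuming the public `FluxPipeline` API (not part of this PR):

```python
# Minimal sketch of the equivalent stock `diffusers` call (assumption:
# public FluxPipeline API; not part of this PR). The elastic_models
# snippet above differs only in the import and the extra `mode` argument.
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipeline.to(torch.device("cuda"))

# Generate one image and save it to disk.
image = pipeline(prompt="Kitten eating a banana").images[0]
image.save("Kitten_eating_a_banana.png")
```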