farapart committed on
Commit
7c6c188
·
verified ·
1 Parent(s): 0989d4b

Initial commit with folder contents

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -3
  2. pyproject.toml +3 -4
  3. src/pipeline.py +8 -42
.gitattributes CHANGED
@@ -32,6 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- RobertML.png filter=lfs diff=lfs merge=lfs -text
37
- backup.png filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
pyproject.toml CHANGED
@@ -27,14 +27,13 @@ repository = "black-forest-labs/FLUX.1-schnell"
27
  revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
28
  exclude = ["transformer"]
29
 
30
- [[tool.edge-maxxing.models]]
31
- repository = "madebyollin/taef1"
32
- revision = "2d552378e58c9c94201075708d7de4e1163b2689"
33
-
34
  [[tool.edge-maxxing.models]]
35
  repository = "farapart/t5_encoder"
36
  revision = "c225a976e16b77764f653a801268de86e20adb84"
37
 
 
 
 
38
 
39
  [project.scripts]
40
  start_inference = "main:main"
 
27
  revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
28
  exclude = ["transformer"]
29
 
 
 
 
 
30
  [[tool.edge-maxxing.models]]
31
  repository = "farapart/t5_encoder"
32
  revision = "c225a976e16b77764f653a801268de86e20adb84"
33
 
34
+ [[tool.edge-maxxing.models]]
35
+ repository = "madebyollin/taef1"
36
+ revision = "2d552378e58c9c94201075708d7de4e1163b2689"
37
 
38
  [project.scripts]
39
  start_inference = "main:main"
src/pipeline.py CHANGED
@@ -1,55 +1,26 @@
1
- from diffusers import (
2
- DiffusionPipeline,
3
- AutoencoderKL,
4
- FluxPipeline,
5
- FluxTransformer2DModel
6
- )
7
- from diffusers.image_processor import VaeImageProcessor
8
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
9
- from huggingface_hub.constants import HF_HUB_CACHE
10
- from transformers import (
11
- T5EncoderModel,
12
- T5TokenizerFast,
13
- CLIPTokenizer,
14
- CLIPTextModel
15
- )
16
  import torch
17
  import torch._dynamo
18
- import gc
 
19
  from PIL import Image
20
  from pipelines.models import TextToImageRequest
21
  from torch import Generator
22
- import time
23
- import math
24
- from typing import Type, Dict, Any, Tuple, Callable, Optional, Union
25
- import numpy as np
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
29
 
30
- # preconfigs
31
- import os
32
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
33
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
34
  torch._dynamo.config.suppress_errors = True
35
  torch.backends.cuda.matmul.allow_tf32 = True
36
  torch.backends.cudnn.enabled = True
37
- # torch.backends.cudnn.benchmark = True
38
 
39
- # globals
40
  Pipeline = None
41
- ckpt_id = "black-forest-labs/FLUX.1-schnell"
42
- ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
43
- TinyVAE = "madebyollin/taef1"
44
- TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
45
-
46
- def empty_cache():
47
- gc.collect()
48
- torch.cuda.empty_cache()
49
- torch.cuda.reset_max_memory_allocated()
50
- torch.cuda.reset_peak_memory_stats()
51
 
52
  def load_pipeline() -> Pipeline:
 
 
53
  text_encoder_2 = T5EncoderModel.from_pretrained("farapart/t5_encoder", revision = "c225a976e16b77764f653a801268de86e20adb84", subfolder="text_encoder_2",torch_dtype=torch.bfloat16)
54
  path = os.path.join(HF_HUB_CACHE, "models--farapart--t5_encoder/snapshots/c225a976e16b77764f653a801268de86e20adb84/transformer")
55
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
@@ -60,11 +31,6 @@ def load_pipeline() -> Pipeline:
60
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
61
  return pipeline
62
 
63
- sample = 1
64
  @torch.no_grad()
65
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
66
- global sample
67
- if not sample:
68
- sample=1
69
- empty_cache()
70
  return pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch._dynamo
3
+ import os
4
+ import torch.nn.functional as F
5
  from PIL import Image
6
  from pipelines.models import TextToImageRequest
7
  from torch import Generator
8
+ from typing import Type
9
+ from diffusers import DiffusionPipeline, FluxTransformer2DModel
10
+ from huggingface_hub.constants import HF_HUB_CACHE
11
+ from transformers import T5EncoderModel
 
 
 
12
 
 
 
13
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
14
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
15
  torch._dynamo.config.suppress_errors = True
16
  torch.backends.cuda.matmul.allow_tf32 = True
17
  torch.backends.cudnn.enabled = True
 
18
 
 
19
  Pipeline = None
 
 
 
 
 
 
 
 
 
 
20
 
21
  def load_pipeline() -> Pipeline:
22
+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
23
+ ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
24
  text_encoder_2 = T5EncoderModel.from_pretrained("farapart/t5_encoder", revision = "c225a976e16b77764f653a801268de86e20adb84", subfolder="text_encoder_2",torch_dtype=torch.bfloat16)
25
  path = os.path.join(HF_HUB_CACHE, "models--farapart--t5_encoder/snapshots/c225a976e16b77764f653a801268de86e20adb84/transformer")
26
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
 
31
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
32
  return pipeline
33
 
 
34
  @torch.no_grad()
35
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
 
 
 
 
36
  return pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]