Spaces:
Runtime error
Runtime error
style: use isort
Browse files- .github/workflows/black.yml +0 -14
- app/gradio/app_gradio.py +6 -14
- app/streamlit/app.py +2 -1
- app/streamlit/backend.py +3 -2
- dalle_mini/data.py +4 -2
- dalle_mini/model.py +6 -8
- dalle_mini/text.py +5 -3
- tools/train/train.py +8 -14
.github/workflows/black.yml
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
name: Lint
|
| 2 |
-
|
| 3 |
-
on:
|
| 4 |
-
push:
|
| 5 |
-
branches: [main]
|
| 6 |
-
pull_request:
|
| 7 |
-
branches: [main]
|
| 8 |
-
|
| 9 |
-
jobs:
|
| 10 |
-
lint:
|
| 11 |
-
runs-on: ubuntu-latest
|
| 12 |
-
steps:
|
| 13 |
-
- uses: actions/checkout@v2
|
| 14 |
-
- uses: psf/black@stable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/gradio/app_gradio.py
CHANGED
|
@@ -7,26 +7,18 @@
|
|
| 7 |
|
| 8 |
import random
|
| 9 |
|
|
|
|
| 10 |
import jax
|
| 11 |
-
import
|
| 12 |
-
from flax.training.common_utils import shard
|
| 13 |
from flax.jax_utils import replicate
|
| 14 |
-
|
| 15 |
-
from transformers import BartTokenizer
|
| 16 |
-
|
| 17 |
from PIL import Image, ImageDraw, ImageFont
|
| 18 |
-
import numpy as np
|
| 19 |
-
|
| 20 |
-
from vqgan_jax.modeling_flax_vqgan import VQModel
|
| 21 |
-
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 22 |
|
| 23 |
# ## CLIP Scoring
|
| 24 |
-
from transformers import CLIPProcessor, FlaxCLIPModel
|
| 25 |
-
|
| 26 |
-
import gradio as gr
|
| 27 |
-
|
| 28 |
-
from PIL import Image, ImageDraw, ImageFont
|
| 29 |
|
|
|
|
| 30 |
|
| 31 |
DALLE_REPO = "flax-community/dalle-mini"
|
| 32 |
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
|
|
|
| 7 |
|
| 8 |
import random
|
| 9 |
|
| 10 |
+
import gradio as gr
|
| 11 |
import jax
|
| 12 |
+
import numpy as np
|
|
|
|
| 13 |
from flax.jax_utils import replicate
|
| 14 |
+
from flax.training.common_utils import shard
|
|
|
|
|
|
|
| 15 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# ## CLIP Scoring
|
| 18 |
+
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
|
| 19 |
+
from vqgan_jax.modeling_flax_vqgan import VQModel
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 22 |
|
| 23 |
DALLE_REPO = "flax-community/dalle-mini"
|
| 24 |
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
app/streamlit/app.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding: utf-8
|
| 3 |
|
| 4 |
-
from .backend import ServiceError, get_images_from_backend
|
| 5 |
import streamlit as st
|
| 6 |
|
|
|
|
|
|
|
| 7 |
st.sidebar.markdown(
|
| 8 |
"""
|
| 9 |
<style>
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding: utf-8
|
| 3 |
|
|
|
|
| 4 |
import streamlit as st
|
| 5 |
|
| 6 |
+
from .backend import ServiceError, get_images_from_backend
|
| 7 |
+
|
| 8 |
st.sidebar.markdown(
|
| 9 |
"""
|
| 10 |
<style>
|
app/streamlit/backend.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
import requests
|
| 2 |
-
from io import BytesIO
|
| 3 |
import base64
|
|
|
|
|
|
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import base64
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
from PIL import Image
|
| 6 |
|
| 7 |
|
dalle_mini/data.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
from dataclasses import dataclass, field
|
| 2 |
-
from datasets import load_dataset, Dataset
|
| 3 |
from functools import partial
|
| 4 |
-
|
| 5 |
import jax
|
| 6 |
import jax.numpy as jnp
|
|
|
|
|
|
|
| 7 |
from flax.training.common_utils import shard
|
|
|
|
| 8 |
from .text import TextNormalizer
|
| 9 |
|
| 10 |
|
|
|
|
| 1 |
from dataclasses import dataclass, field
|
|
|
|
| 2 |
from functools import partial
|
| 3 |
+
|
| 4 |
import jax
|
| 5 |
import jax.numpy as jnp
|
| 6 |
+
import numpy as np
|
| 7 |
+
from datasets import Dataset, load_dataset
|
| 8 |
from flax.training.common_utils import shard
|
| 9 |
+
|
| 10 |
from .text import TextNormalizer
|
| 11 |
|
| 12 |
|
dalle_mini/model.py
CHANGED
|
@@ -1,16 +1,14 @@
|
|
| 1 |
-
import jax
|
| 2 |
import flax.linen as nn
|
| 3 |
-
|
|
|
|
| 4 |
from transformers.models.bart.modeling_flax_bart import (
|
| 5 |
-
FlaxBartModule,
|
| 6 |
-
FlaxBartForConditionalGenerationModule,
|
| 7 |
-
FlaxBartForConditionalGeneration,
|
| 8 |
-
FlaxBartEncoder,
|
| 9 |
FlaxBartDecoder,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
)
|
| 11 |
|
| 12 |
-
from transformers import BartConfig
|
| 13 |
-
|
| 14 |
|
| 15 |
class CustomFlaxBartModule(FlaxBartModule):
|
| 16 |
def setup(self):
|
|
|
|
|
|
|
| 1 |
import flax.linen as nn
|
| 2 |
+
import jax
|
| 3 |
+
from transformers import BartConfig
|
| 4 |
from transformers.models.bart.modeling_flax_bart import (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
FlaxBartDecoder,
|
| 6 |
+
FlaxBartEncoder,
|
| 7 |
+
FlaxBartForConditionalGeneration,
|
| 8 |
+
FlaxBartForConditionalGenerationModule,
|
| 9 |
+
FlaxBartModule,
|
| 10 |
)
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class CustomFlaxBartModule(FlaxBartModule):
|
| 14 |
def setup(self):
|
dalle_mini/text.py
CHANGED
|
@@ -2,13 +2,15 @@
|
|
| 2 |
Utilities for processing text.
|
| 3 |
"""
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
-
from unidecode import unidecode
|
| 7 |
|
| 8 |
-
import re, math, random, html
|
| 9 |
import ftfy
|
| 10 |
-
|
| 11 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
|
| 13 |
# based on wiki word occurence
|
| 14 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
|
|
|
| 2 |
Utilities for processing text.
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import html
|
| 6 |
+
import math
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
from pathlib import Path
|
|
|
|
| 10 |
|
|
|
|
| 11 |
import ftfy
|
|
|
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
+
from unidecode import unidecode
|
| 14 |
|
| 15 |
# based on wiki word occurence
|
| 16 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
tools/train/train.py
CHANGED
|
@@ -18,37 +18,31 @@ Fine-tuning the library models for seq2seq, text to image.
|
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
import
|
| 22 |
import logging
|
|
|
|
| 23 |
import sys
|
| 24 |
import time
|
| 25 |
-
from dataclasses import dataclass, field
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import Callable, Optional
|
| 28 |
-
import json
|
| 29 |
|
| 30 |
import datasets
|
| 31 |
-
from datasets import Dataset
|
| 32 |
-
from tqdm import tqdm
|
| 33 |
-
from dataclasses import asdict
|
| 34 |
-
|
| 35 |
import jax
|
| 36 |
import jax.numpy as jnp
|
| 37 |
import optax
|
| 38 |
import transformers
|
|
|
|
|
|
|
| 39 |
from flax import jax_utils, traverse_util
|
| 40 |
-
from flax.serialization import from_bytes, to_bytes
|
| 41 |
from flax.jax_utils import unreplicate
|
|
|
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
| 44 |
-
from
|
| 45 |
-
|
| 46 |
-
HfArgumentParser,
|
| 47 |
-
)
|
| 48 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
| 49 |
|
| 50 |
-
import wandb
|
| 51 |
-
|
| 52 |
from dalle_mini.data import Dataset
|
| 53 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 54 |
|
|
|
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
import json
|
| 22 |
import logging
|
| 23 |
+
import os
|
| 24 |
import sys
|
| 25 |
import time
|
| 26 |
+
from dataclasses import asdict, dataclass, field
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Callable, Optional
|
|
|
|
| 29 |
|
| 30 |
import datasets
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
import jax
|
| 32 |
import jax.numpy as jnp
|
| 33 |
import optax
|
| 34 |
import transformers
|
| 35 |
+
import wandb
|
| 36 |
+
from datasets import Dataset
|
| 37 |
from flax import jax_utils, traverse_util
|
|
|
|
| 38 |
from flax.jax_utils import unreplicate
|
| 39 |
+
from flax.serialization import from_bytes, to_bytes
|
| 40 |
from flax.training import train_state
|
| 41 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
| 42 |
+
from tqdm import tqdm
|
| 43 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
|
|
|
|
|
|
| 44 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
| 45 |
|
|
|
|
|
|
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 48 |
|