pszemraj joaogante committed on
Commit
a0f76b5
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: joaogante <joaogante@users.noreply.huggingface.co>

Files changed (5) hide show
  1. .gitattributes +34 -0
  2. .gitignore +169 -0
  3. README.md +16 -0
  4. app.py +168 -0
  5. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Initially taken from Github's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+
168
+ # ruff
169
+ .ruff_cache
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: text2text Instruct playground
3
+ emoji: 🗨️⌛
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: joaogante/transformers_streaming
11
+ license: apache-2.0
12
+ ---
13
+
14
+ waddup
15
+
16
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from threading import Thread
import logging
import time

# Configure root logging before the heavyweight imports below so that any
# records emitted during their import are formatted consistently.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer

# Seq2seq (T5-style) instruction-tuned model served by this Space.
model_id = "pszemraj/tFINE-850m-24x24-v0.5-instruct-L1"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Running on device:\t {torch_device}")
logging.info(f"CPU threads:\t {torch.get_num_threads()}")


if torch_device == "cuda":
    # On GPU: load the weights quantized to 8-bit (requires bitsandbytes)
    # and let accelerate place them across devices via device_map.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id, load_in_8bit=True, device_map="auto"
    )
else:
    # On CPU: load in bfloat16, then try torch.compile as a best-effort
    # speedup — compilation failures are logged and the eager model is kept.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    try:
        model = torch.compile(model)
    except Exception as e:
        logging.error(f"Unable to compile model:\t{e}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
+
33
+
34
+ def run_generation(
35
+ user_text,
36
+ top_p,
37
+ temperature,
38
+ top_k,
39
+ max_new_tokens,
40
+ repetition_penalty=1.1,
41
+ length_penalty=1.0,
42
+ no_repeat_ngram_size=4,
43
+ use_generation_config=False,
44
+ ):
45
+ st = time.perf_counter()
46
+ # Get the model and tokenizer, and tokenize the user text.
47
+ model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
48
+
49
+ # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
50
+ # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
51
+ streamer = TextIteratorStreamer(
52
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
53
+ )
54
+ generate_kwargs = dict(
55
+ model_inputs,
56
+ streamer=streamer,
57
+ max_new_tokens=max_new_tokens,
58
+ do_sample=True,
59
+ num_beams=1,
60
+ top_p=top_p,
61
+ temperature=float(temperature),
62
+ top_k=top_k,
63
+ repetition_penalty=repetition_penalty,
64
+ length_penalty=length_penalty,
65
+ no_repeat_ngram_size=no_repeat_ngram_size,
66
+ renormalize_logits=True,
67
+ )
68
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
69
+ t.start()
70
+
71
+ # Pull the generated text from the streamer, and update the model output.
72
+ model_output = ""
73
+ for new_text in streamer:
74
+ model_output += new_text
75
+ yield model_output
76
+ logging.info("Total rt:\t{rt} sec".format(rt=round(time.perf_counter() - st, 3)))
77
+ return model_output
78
+
79
+
80
def reset_textbox():
    """Return a component update that clears the textbox value."""
    cleared = gr.update(value="")
    return cleared
82
+
83
+
84
with gr.Blocks() as demo:
    # One-click clone link surfaced in the header markdown.
    duplicate_link = (
        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
    )
    intro_pieces = [
        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n",
        "This demo showcases the use of the ",
        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) ",
        "of 🤗 Transformers with Gradio to generate text in real-time. It uses ",
        f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n",
        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a ",
        "template! 💛",
    ]
    gr.Markdown("".join(intro_pieces))
    gr.Markdown("---")
    with gr.Row():
        # Left column: prompt input, streamed output, and the submit button.
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                label="User input",
                value="How to become a polar bear tamer?",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit", variant="primary")

        # Right column: sampling / decoding controls.
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                label="Max New Tokens",
                minimum=32,
                maximum=1024,
                value=256,
                step=32,
                interactive=True,
            )
            top_p = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
            )
            top_k = gr.Slider(
                label="Top-k",
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=1.4,
                value=0.3,
                step=0.05,
                interactive=True,
            )
            repetition_penalty = gr.Slider(
                label="Repetition Penalty",
                minimum=0.9,
                maximum=2.5,
                value=1.1,
                step=0.1,
                interactive=True,
            )
            length_penalty = gr.Slider(
                label="Length Penalty",
                minimum=0.8,
                maximum=1.5,
                value=1.0,
                step=0.1,
                interactive=True,
            )

    # Both Enter-in-textbox and the button trigger the same streamed call.
    generation_inputs = [
        user_text,
        top_p,
        temperature,
        top_k,
        max_new_tokens,
        repetition_penalty,
        length_penalty,
    ]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

demo.queue(max_size=10).launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate
2
+ bitsandbytes
3
+ torch
4
+ transformers
5
+ gradio
6
+ sentencepiece