sangche commited on
Commit
b82e01c
·
1 Parent(s): ad2d1f6

hysts/zephyr-7b

Browse files
Files changed (7) hide show
  1. LICENSE +21 -0
  2. app.py +120 -18
  3. requirements.txt +240 -4
  4. static/styles.css +0 -7
  5. style.css +11 -0
  6. templates/index.html +0 -165
  7. templates/item.html +0 -14
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py CHANGED
@@ -1,25 +1,127 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
 
6
- app = FastAPI()
 
 
7
 
8
- app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 
9
 
10
- templates = Jinja2Templates(directory="templates")
11
 
12
- @app.get("/", response_class=HTMLResponse)
13
- async def greet_json(request: Request):
14
- return templates.TemplateResponse(
15
- request=request, name="index.html", context={}
16
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- @app.get("/items/{id}", response_class=HTMLResponse)
19
- async def read_item(request: Request, id: str):
20
- return templates.TemplateResponse(
21
- request=request, name="item.html", context={"id": id}
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # to run server
25
- # $ uvicorn app:app --host 0.0.0.0 --port 7860
 
1
+ #!/usr/bin/env python
 
 
 
2
 
3
+ import os
4
+ from collections.abc import Iterator
5
+ from threading import Thread
6
 
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
12
+ DESCRIPTION = "# Zephyr-7B beta"
13
 
14
+ if not torch.cuda.is_available():
15
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
16
+
17
+ MAX_MAX_NEW_TOKENS = 2048
18
+ DEFAULT_MAX_NEW_TOKENS = 1024
19
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
20
+
21
+ if torch.cuda.is_available():
22
+ model_id = "HuggingFaceH4/zephyr-7b-beta"
23
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
24
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
25
+
26
+
27
+ @spaces.GPU
28
+ def generate(
29
+ message: str,
30
+ chat_history: list[dict],
31
+ system_prompt: str = "",
32
+ max_new_tokens: int = 1024,
33
+ temperature: float = 0.7,
34
+ top_p: float = 0.95,
35
+ top_k: int = 50,
36
+ repetition_penalty: float = 1.0,
37
+ ) -> Iterator[str]:
38
+ conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
39
+ conversation += [*chat_history, {"role": "user", "content": message}]
40
 
41
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
42
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
43
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
44
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
45
+ input_ids = input_ids.to(model.device)
46
+
47
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
48
+ generate_kwargs = dict(
49
+ {"input_ids": input_ids},
50
+ streamer=streamer,
51
+ max_new_tokens=max_new_tokens,
52
+ do_sample=True,
53
+ top_p=top_p,
54
+ top_k=top_k,
55
+ temperature=temperature,
56
+ num_beams=1,
57
+ repetition_penalty=repetition_penalty,
58
  )
59
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
60
+ t.start()
61
+
62
+ outputs = []
63
+ for text in streamer:
64
+ outputs.append(text)
65
+ yield "".join(outputs)
66
+
67
+
68
+ demo = gr.ChatInterface(
69
+ fn=generate,
70
+ additional_inputs=[
71
+ gr.Textbox(
72
+ label="System prompt",
73
+ lines=6,
74
+ placeholder="You are a friendly chatbot who always responds in the style of a pirate.",
75
+ ),
76
+ gr.Slider(
77
+ label="Max new tokens",
78
+ minimum=1,
79
+ maximum=MAX_MAX_NEW_TOKENS,
80
+ step=1,
81
+ value=DEFAULT_MAX_NEW_TOKENS,
82
+ ),
83
+ gr.Slider(
84
+ label="Temperature",
85
+ minimum=0.1,
86
+ maximum=4.0,
87
+ step=0.1,
88
+ value=0.7,
89
+ ),
90
+ gr.Slider(
91
+ label="Top-p (nucleus sampling)",
92
+ minimum=0.05,
93
+ maximum=1.0,
94
+ step=0.05,
95
+ value=0.95,
96
+ ),
97
+ gr.Slider(
98
+ label="Top-k",
99
+ minimum=1,
100
+ maximum=1000,
101
+ step=1,
102
+ value=50,
103
+ ),
104
+ gr.Slider(
105
+ label="Repetition penalty",
106
+ minimum=1.0,
107
+ maximum=2.0,
108
+ step=0.05,
109
+ value=1.0,
110
+ ),
111
+ ],
112
+ stop_btn=None,
113
+ examples=[
114
+ ["Hello there! How are you doing?"],
115
+ ["Can you explain briefly to me what is the Python programming language?"],
116
+ ["Explain the plot of Cinderella in a sentence."],
117
+ ["How many hours does it take a man to eat a Helicopter?"],
118
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
119
+ ],
120
+ type="messages",
121
+ description=DESCRIPTION,
122
+ css_paths="style.css",
123
+ )
124
+
125
 
126
+ if __name__ == "__main__":
127
+ demo.queue(max_size=20).launch()
requirements.txt CHANGED
@@ -1,4 +1,240 @@
1
- fastapi
2
- jinja2
3
- uvicorn[standard]
4
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.2.1
4
+ # via zephyr-7b (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.8.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ certifi==2024.12.14
15
+ # via
16
+ # httpcore
17
+ # httpx
18
+ # requests
19
+ charset-normalizer==3.4.1
20
+ # via requests
21
+ click==8.1.8
22
+ # via
23
+ # typer
24
+ # uvicorn
25
+ exceptiongroup==1.2.2
26
+ # via anyio
27
+ fastapi==0.115.6
28
+ # via gradio
29
+ ffmpy==0.5.0
30
+ # via gradio
31
+ filelock==3.16.1
32
+ # via
33
+ # huggingface-hub
34
+ # torch
35
+ # transformers
36
+ # triton
37
+ fsspec==2024.12.0
38
+ # via
39
+ # gradio-client
40
+ # huggingface-hub
41
+ # torch
42
+ gradio==5.12.0
43
+ # via
44
+ # zephyr-7b (pyproject.toml)
45
+ # spaces
46
+ gradio-client==1.5.4
47
+ # via gradio
48
+ h11==0.14.0
49
+ # via
50
+ # httpcore
51
+ # uvicorn
52
+ hf-transfer==0.1.9
53
+ # via zephyr-7b (pyproject.toml)
54
+ httpcore==1.0.7
55
+ # via httpx
56
+ httpx==0.28.1
57
+ # via
58
+ # gradio
59
+ # gradio-client
60
+ # safehttpx
61
+ # spaces
62
+ huggingface-hub==0.27.1
63
+ # via
64
+ # accelerate
65
+ # gradio
66
+ # gradio-client
67
+ # tokenizers
68
+ # transformers
69
+ idna==3.10
70
+ # via
71
+ # anyio
72
+ # httpx
73
+ # requests
74
+ jinja2==3.1.5
75
+ # via
76
+ # gradio
77
+ # torch
78
+ markdown-it-py==3.0.0
79
+ # via rich
80
+ markupsafe==2.1.5
81
+ # via
82
+ # gradio
83
+ # jinja2
84
+ mdurl==0.1.2
85
+ # via markdown-it-py
86
+ mpmath==1.3.0
87
+ # via sympy
88
+ networkx==3.4.2
89
+ # via torch
90
+ numpy==2.2.1
91
+ # via
92
+ # accelerate
93
+ # gradio
94
+ # pandas
95
+ # transformers
96
+ nvidia-cublas-cu12==12.1.3.1
97
+ # via
98
+ # nvidia-cudnn-cu12
99
+ # nvidia-cusolver-cu12
100
+ # torch
101
+ nvidia-cuda-cupti-cu12==12.1.105
102
+ # via torch
103
+ nvidia-cuda-nvrtc-cu12==12.1.105
104
+ # via torch
105
+ nvidia-cuda-runtime-cu12==12.1.105
106
+ # via torch
107
+ nvidia-cudnn-cu12==9.1.0.70
108
+ # via torch
109
+ nvidia-cufft-cu12==11.0.2.54
110
+ # via torch
111
+ nvidia-curand-cu12==10.3.2.106
112
+ # via torch
113
+ nvidia-cusolver-cu12==11.4.5.107
114
+ # via torch
115
+ nvidia-cusparse-cu12==12.1.0.106
116
+ # via
117
+ # nvidia-cusolver-cu12
118
+ # torch
119
+ nvidia-nccl-cu12==2.20.5
120
+ # via torch
121
+ nvidia-nvjitlink-cu12==12.6.85
122
+ # via
123
+ # nvidia-cusolver-cu12
124
+ # nvidia-cusparse-cu12
125
+ nvidia-nvtx-cu12==12.1.105
126
+ # via torch
127
+ orjson==3.10.14
128
+ # via gradio
129
+ packaging==24.2
130
+ # via
131
+ # accelerate
132
+ # gradio
133
+ # gradio-client
134
+ # huggingface-hub
135
+ # spaces
136
+ # transformers
137
+ pandas==2.2.3
138
+ # via gradio
139
+ pillow==11.1.0
140
+ # via gradio
141
+ psutil==5.9.8
142
+ # via
143
+ # accelerate
144
+ # spaces
145
+ pydantic==2.10.5
146
+ # via
147
+ # fastapi
148
+ # gradio
149
+ # spaces
150
+ pydantic-core==2.27.2
151
+ # via pydantic
152
+ pydub==0.25.1
153
+ # via gradio
154
+ pygments==2.19.1
155
+ # via rich
156
+ python-dateutil==2.9.0.post0
157
+ # via pandas
158
+ python-multipart==0.0.20
159
+ # via gradio
160
+ pytz==2024.2
161
+ # via pandas
162
+ pyyaml==6.0.2
163
+ # via
164
+ # accelerate
165
+ # gradio
166
+ # huggingface-hub
167
+ # transformers
168
+ regex==2024.11.6
169
+ # via transformers
170
+ requests==2.32.3
171
+ # via
172
+ # huggingface-hub
173
+ # spaces
174
+ # transformers
175
+ rich==13.9.4
176
+ # via typer
177
+ ruff==0.9.1
178
+ # via gradio
179
+ safehttpx==0.1.6
180
+ # via gradio
181
+ safetensors==0.5.2
182
+ # via
183
+ # accelerate
184
+ # transformers
185
+ semantic-version==2.10.0
186
+ # via gradio
187
+ shellingham==1.5.4
188
+ # via typer
189
+ six==1.17.0
190
+ # via python-dateutil
191
+ sniffio==1.3.1
192
+ # via anyio
193
+ spaces==0.32.0
194
+ # via zephyr-7b (pyproject.toml)
195
+ starlette==0.41.3
196
+ # via
197
+ # fastapi
198
+ # gradio
199
+ sympy==1.13.3
200
+ # via torch
201
+ tokenizers==0.21.0
202
+ # via transformers
203
+ tomlkit==0.13.2
204
+ # via gradio
205
+ torch==2.4.0
206
+ # via
207
+ # zephyr-7b (pyproject.toml)
208
+ # accelerate
209
+ tqdm==4.67.1
210
+ # via
211
+ # huggingface-hub
212
+ # transformers
213
+ transformers==4.48.0
214
+ # via zephyr-7b (pyproject.toml)
215
+ triton==3.0.0
216
+ # via torch
217
+ typer==0.15.1
218
+ # via gradio
219
+ typing-extensions==4.12.2
220
+ # via
221
+ # anyio
222
+ # fastapi
223
+ # gradio
224
+ # gradio-client
225
+ # huggingface-hub
226
+ # pydantic
227
+ # pydantic-core
228
+ # rich
229
+ # spaces
230
+ # torch
231
+ # typer
232
+ # uvicorn
233
+ tzdata==2024.2
234
+ # via pandas
235
+ urllib3==2.3.0
236
+ # via requests
237
+ uvicorn==0.34.0
238
+ # via gradio
239
+ websockets==14.1
240
+ # via gradio-client
static/styles.css DELETED
@@ -1,7 +0,0 @@
1
- h1 a {
2
- color: red;
3
- }
4
-
5
- h2 {
6
- color: red;
7
- }
 
 
 
 
 
 
 
 
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: white;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }
templates/index.html DELETED
@@ -1,165 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>Compliment Bot 💖</title>
7
- <!-- <script src="https://cdn.jsdelivr.net/npm/@gradio/client@1.2.0/dist/index.min.js"></script> -->
8
- <style>
9
- body {
10
- font-family: Arial, sans-serif;
11
- display: flex;
12
- justify-content: center;
13
- align-items: center;
14
- height: 100vh;
15
- margin: 0;
16
- background-color: #f0f0f0;
17
- }
18
- .container {
19
- text-align: center;
20
- background-color: white;
21
- padding: 20px;
22
- border-radius: 10px;
23
- box-shadow: 0 0 10px rgba(0,0,0,0.1);
24
- max-width: 400px;
25
- width: 100%;
26
- }
27
- #headshot {
28
- max-width: 300px;
29
- max-height: 300px;
30
- margin: 20px auto;
31
- display: block;
32
- }
33
- #compliment {
34
- font-size: 18px;
35
- font-weight: bold;
36
- color: #4a4a4a;
37
- min-height: 50px;
38
- }
39
- .loader {
40
- border: 5px solid #f3f3f3;
41
- border-top: 5px solid #3498db;
42
- border-radius: 50%;
43
- width: 30px;
44
- height: 30px;
45
- animation: spin 1s linear infinite;
46
- margin: 20px auto;
47
- display: none;
48
- }
49
- @keyframes spin {
50
- 0% { transform: rotate(0deg); }
51
- 100% { transform: rotate(360deg); }
52
- }
53
- #uploadButton {
54
- background-color: #4CAF50;
55
- border: none;
56
- color: white;
57
- padding: 10px 20px;
58
- text-align: center;
59
- text-decoration: none;
60
- display: inline-block;
61
- font-size: 16px;
62
- margin: 4px 2px;
63
- cursor: pointer;
64
- border-radius: 5px;
65
- }
66
- </style>
67
- </head>
68
- <body>
69
- <div class="container">
70
- <h1> Compliment Bot 💖</h1>
71
- <input type="file" id="fileInput" accept="image/*" style="display: none;">
72
- <button id="uploadButton">Upload New Headshot</button>
73
- <br><br>
74
- <img id="headshot" src="" alt="Your headshot" style="display:none;">
75
- <div class="loader" id="loader"></div>
76
- <p id="compliment"></p>
77
- </div>
78
-
79
- <script>
80
- const fileInput = document.getElementById('fileInput');
81
- const uploadButton = document.getElementById('uploadButton');
82
- const headshot = document.getElementById('headshot');
83
- const compliment = document.getElementById('compliment');
84
- const loader = document.getElementById('loader');
85
-
86
- const SYSTEM_PROMPT = `
87
- You are helpful assistant that gives the best compliments to people.
88
- You will be given a caption of someone's headshot.
89
- Based on that caption, provide a one sentence compliment to the person in the image.
90
- Make sure you compliment the person in the image and not any objects or scenery.
91
- Do NOT include any hashtags in your compliment or phrases like (emojis: dog, smiling face with heart-eyes, sun).
92
-
93
- Here are some examples of the desired behavior:
94
-
95
- Caption: a front view of a man who is smiling, there is a lighthouse in the background, there is a grassy area on the left that is green and curved. in the distance you can see the ocean and the shore. there is a grey and cloudy sky above the lighthouse and the trees.
96
- Compliment: Your smile is as bright as a lighthouse, lighting up the world around you. 🌟
97
-
98
- Caption: in a close-up, a blonde woman with short, wavy hair, is the focal point of the image. she's dressed in a dark brown turtleneck sweater, paired with a black hat and a black suit jacket. her lips are a vibrant red, and her eyes are a deep brown. in the background, a man with a black hat and a white shirt is visible.
99
- Compliment: You are the epitome of elegance and grace, with a style that is as timeless as your beauty. 💃🎩
100
-
101
- Conversation begins below:
102
-
103
- `
104
-
105
- uploadButton.addEventListener('click', function() {
106
- fileInput.click();
107
- });
108
-
109
- fileInput.addEventListener('change', function(e) {
110
- const file = e.target.files[0];
111
- if (file) {
112
- const reader = new FileReader();
113
- reader.onload = function(event) {
114
- headshot.src = event.target.result;
115
- headshot.style.display = 'block';
116
- generateCompliment(file);
117
- }
118
- reader.readAsDataURL(file);
119
- }
120
- });
121
-
122
- async function generateCompliment(file) {
123
- compliment.textContent = '';
124
- loader.style.display = 'block';
125
-
126
- try {
127
- const client_lib = await import("https://cdn.jsdelivr.net/npm/@gradio/client@1.2.0/dist/index.min.js");
128
- const Client = client_lib.Client;
129
- const handle_file = client_lib.handle_file;
130
- const captioning_space = await Client.connect("gokaygokay/SD3-Long-Captioner");
131
- const llm_space = await Client.connect("hysts/zephyr-7b");
132
-
133
- const caption = await captioning_space.predict("/create_captions_rich", { image: file });
134
-
135
-
136
- console.info("Caption", caption.data);
137
-
138
- const submission = llm_space.submit("/chat", {
139
- system_prompt: SYSTEM_PROMPT,
140
- message: `Caption: ${caption.data}\nCompliment: `,
141
- max_new_tokens: 256,
142
- temperature: 0.7,
143
- top_p: 0.95,
144
- top_k: 50,
145
- repetition_penalty: 1,
146
- }
147
- )
148
-
149
- for await (const msg of submission) {
150
- loader.style.display = 'none';
151
- if (msg.type === "data") {
152
- console.log("msg.data", msg.data);
153
- compliment.textContent = msg.data[0]
154
- }
155
- }
156
- } catch (error) {
157
- console.error('Error:', error);
158
- loader.style.display = 'none';
159
- compliment.textContent = "Oops! We couldn't generate a compliment. You're still awesome though!"
160
- }
161
-
162
- }
163
- </script>
164
- </body>
165
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates/item.html DELETED
@@ -1,14 +0,0 @@
1
- <html>
2
-
3
- <head>
4
- <title>Item Details</title>
5
- <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
6
- </head>
7
-
8
- <body>
9
- <h1>Hello <a href="{{ url_for('read_item', id=id) }}">Item ID: {{ id }}</a></h1>
10
-
11
- <h2>Hello...</h2>
12
- </body>
13
-
14
- </html>