Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import torch
|
| 3 |
from flashsloth.constants import (
|
| 4 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
|
|
@@ -14,28 +15,39 @@ from flashsloth.mm_utils import (
|
|
| 14 |
from PIL import Image
|
| 15 |
import gradio as gr
|
| 16 |
|
| 17 |
-
|
| 18 |
from transformers import TextIteratorStreamer
|
| 19 |
from threading import Thread
|
| 20 |
|
| 21 |
-
|
| 22 |
disable_torch_init()
|
| 23 |
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
model.to('cuda')
|
| 29 |
-
model.eval()
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
keywords = ['</s>']
|
| 34 |
|
|
|
|
| 35 |
|
| 36 |
text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
|
| 37 |
text = text + LEARNABLE_TOKEN
|
| 38 |
-
|
| 39 |
|
| 40 |
image = image.convert('RGB')
|
| 41 |
if model.config.image_hd:
|
|
@@ -43,14 +55,12 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
|
|
| 43 |
else:
|
| 44 |
image_tensor = process_images([image], image_processor, model.config)[0]
|
| 45 |
image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)
|
| 46 |
-
|
| 47 |
|
| 48 |
conv = conv_templates["phi2"].copy()
|
| 49 |
conv.append_message(conv.roles[0], text)
|
| 50 |
conv.append_message(conv.roles[1], None)
|
| 51 |
prompt = conv.get_prompt()
|
| 52 |
|
| 53 |
-
|
| 54 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
| 55 |
input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
|
| 56 |
|
|
@@ -79,11 +89,9 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
|
|
| 79 |
with torch.inference_mode():
|
| 80 |
model.generate(**generation_kwargs)
|
| 81 |
|
| 82 |
-
# 在单独线程中运行生成，防止阻塞
|
| 83 |
generation_thread = Thread(target=_generate)
|
| 84 |
generation_thread.start()
|
| 85 |
|
| 86 |
-
# 边生成边yield输出
|
| 87 |
partial_text = ""
|
| 88 |
for new_text in streamer:
|
| 89 |
partial_text += new_text
|
|
@@ -91,7 +99,6 @@ def generate_description(image, prompt_text, temperature, top_p, max_tokens):
|
|
| 91 |
|
| 92 |
generation_thread.join()
|
| 93 |
|
| 94 |
-
# 自定义CSS样式，用于增大字体和美化界面
|
| 95 |
custom_css = """
|
| 96 |
<style>
|
| 97 |
/* 增大标题字体 */
|
|
@@ -152,10 +159,17 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 152 |
minimum=64,
|
| 153 |
maximum=3072,
|
| 154 |
step=1,
|
| 155 |
-
value=
|
| 156 |
label="Max Tokens"
|
| 157 |
)
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
with gr.Column(scale=1):
|
| 160 |
prompt_input = gr.Textbox(
|
| 161 |
lines=3,
|
|
@@ -173,10 +187,10 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 173 |
|
| 174 |
submit_button.click(
|
| 175 |
fn=generate_description,
|
| 176 |
-
inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
|
| 177 |
outputs=output_text,
|
| 178 |
show_progress=True
|
| 179 |
)
|
| 180 |
|
| 181 |
if __name__ == "__main__":
|
| 182 |
-
demo.queue().launch()
|
|
|
|
| 1 |
import os
|
| 2 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
| 3 |
import torch
|
| 4 |
from flashsloth.constants import (
|
| 5 |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
|
|
|
|
| 15 |
from PIL import Image
|
| 16 |
import gradio as gr
|
| 17 |
|
|
|
|
| 18 |
from transformers import TextIteratorStreamer
|
| 19 |
from threading import Thread
|
| 20 |
|
|
|
|
| 21 |
disable_torch_init()
|
| 22 |
|
| 23 |
+
MODEL_PATH_HD = "Tongbo/FlashSloth_HD-3.2B"
|
| 24 |
+
MODEL_PATH_NEW = "Tongbo/FlashSloth-3.2B"
|
| 25 |
|
| 26 |
+
model_name_hd = get_model_name_from_path(MODEL_PATH_HD)
|
| 27 |
+
model_name_new = get_model_name_from_path(MODEL_PATH_NEW)
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
models = {
|
| 30 |
+
"FlashSloth HD": load_pretrained_model(MODEL_PATH_HD, None, model_name_hd),
|
| 31 |
+
"FlashSloth": load_pretrained_model(MODEL_PATH_NEW, None, model_name_new)
|
| 32 |
+
}
|
| 33 |
|
| 34 |
+
for key in models:
|
| 35 |
+
tokenizer, model, image_processor, context_len = models[key]
|
| 36 |
+
model.to('cuda')
|
| 37 |
+
model.eval()
|
| 38 |
+
|
| 39 |
+
def generate_description(image, prompt_text, temperature, top_p, max_tokens, selected_model):
|
| 40 |
+
"""
|
| 41 |
+
生成图片描述的函数，支持流式输出，并根据选择的模型进行处理。
|
| 42 |
+
新增参数:
|
| 43 |
+
- selected_model: 用户选择的模型名称
|
| 44 |
+
"""
|
| 45 |
keywords = ['</s>']
|
| 46 |
|
| 47 |
+
tokenizer, model, image_processor, context_len = models[selected_model]
|
| 48 |
|
| 49 |
text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
|
| 50 |
text = text + LEARNABLE_TOKEN
|
|
|
|
| 51 |
|
| 52 |
image = image.convert('RGB')
|
| 53 |
if model.config.image_hd:
|
|
|
|
| 55 |
else:
|
| 56 |
image_tensor = process_images([image], image_processor, model.config)[0]
|
| 57 |
image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)
|
|
|
|
| 58 |
|
| 59 |
conv = conv_templates["phi2"].copy()
|
| 60 |
conv.append_message(conv.roles[0], text)
|
| 61 |
conv.append_message(conv.roles[1], None)
|
| 62 |
prompt = conv.get_prompt()
|
| 63 |
|
|
|
|
| 64 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
| 65 |
input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
|
| 66 |
|
|
|
|
| 89 |
with torch.inference_mode():
|
| 90 |
model.generate(**generation_kwargs)
|
| 91 |
|
|
|
|
| 92 |
generation_thread = Thread(target=_generate)
|
| 93 |
generation_thread.start()
|
| 94 |
|
|
|
|
| 95 |
partial_text = ""
|
| 96 |
for new_text in streamer:
|
| 97 |
partial_text += new_text
|
|
|
|
| 99 |
|
| 100 |
generation_thread.join()
|
| 101 |
|
|
|
|
| 102 |
custom_css = """
|
| 103 |
<style>
|
| 104 |
/* 增大标题字体 */
|
|
|
|
| 159 |
minimum=64,
|
| 160 |
maximum=3072,
|
| 161 |
step=1,
|
| 162 |
+
value=3072,
|
| 163 |
label="Max Tokens"
|
| 164 |
)
|
| 165 |
|
| 166 |
+
model_dropdown = gr.Dropdown(
|
| 167 |
+
choices=list(models.keys()),
|
| 168 |
+
value=list(models.keys())[0],
|
| 169 |
+
label="选择模型",
|
| 170 |
+
type="value"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
with gr.Column(scale=1):
|
| 174 |
prompt_input = gr.Textbox(
|
| 175 |
lines=3,
|
|
|
|
| 187 |
|
| 188 |
submit_button.click(
|
| 189 |
fn=generate_description,
|
| 190 |
+
inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider, model_dropdown],
|
| 191 |
outputs=output_text,
|
| 192 |
show_progress=True
|
| 193 |
)
|
| 194 |
|
| 195 |
if __name__ == "__main__":
|
| 196 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=8888)
|