# app.py
import gradio as gr
import os
import subprocess
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# ⚙️ Install flash-attn at runtime, skipping the CUDA build
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    # Merge with os.environ so pip stays on PATH inside the subprocess
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
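
# Florence-2's remote modeling code can try to import flash_attn at load time,
# which is presumably why it is installed up front even on CPU-only hardware
# (an assumption based on the install hack above; harmless if never imported).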

# 1. Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Load the Florence model and processor
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True
)
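
# NOTE: trust_remote_code=True executes custom Python from the model repo.
# Pinning it with from_pretrained(..., revision="<commit-sha>") would make the
# build reproducible ("<commit-sha>" is a placeholder, not a real pin).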

# 3. Caption-generation function
def generate_caption(image):
    # Guard: a cleared or empty image event can pass None
    if image is None:
        return ""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Instruction asking for a detailed Korean description of 30 to 50 words
    instruction = (
        "Describe this image in detail in Korean, in 30 to 50 words. "
        "Cover the background, colors, textures, the subjects' expressions and "
        "clothing, the lighting, the composition, and the mood."
    )
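    # Caveat: Florence-2 is trained on a fixed set of task prompts such as
    # "<CAPTION>", "<DETAILED_CAPTION>", and "<MORE_DETAILED_CAPTION>", so a
    # free-form instruction like the one above may not be followed literally.
    # If Korean output is not strictly required, a safer choice would be:
    #   instruction = "<MORE_DETAILED_CAPTION>"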
    inputs = florence_processor(
        text=instruction,
        images=image,
        return_tensors="pt"
    ).to(device)
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
        early_stopping=False,
    )
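    # do_sample=False with num_beams=3 gives deterministic beam-search decoding;
    # max_new_tokens=1024 is far more than a 30-to-50-word caption needs, so it
    # mostly acts as a safety ceiling.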
    generated_text = florence_processor.batch_decode(
        generated_ids,
        skip_special_tokens=False
    )[0]
    parsed = florence_processor.post_process_generation(
        generated_text,
        task=instruction,
        image_size=(image.width, image.height)
    )
    prompt = parsed[instruction]
    # Swap "Asian" → "Korean" where needed
    if "Asian" in prompt:
        prompt = prompt.replace("Asian", "Korean")
    print("✅ Generation complete:\n", prompt)
    return prompt
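
# Quick local smoke test, bypassing the UI ("sample.jpg" is a hypothetical file):
#   print(generate_caption(Image.open("sample.jpg")))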

# 4. Build the interface with Gradio Blocks (keeping the caricature button)
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
    gr.Markdown("## 🖼️ Image Caption Generator")
    gr.Markdown(
        "⚠ Currently running in CPU mode, so generation may be slow. "
        "Thanks for your patience."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            # ⇨ lines raised from 3 to 6 to double the textbox height
            caption_output = gr.Textbox(
                label="Generated Caption",
                lines=6,
                show_copy_button=True
            )
    # "Make a Caricature" button at the bottom right
    gr.HTML("""
    <div style='margin-top: 10px; text-align: center;'>
        <a href="https://huggingface.co/spaces/VIDraft/stable-diffusion-3.5-large-turboX" target="_blank">
            <button style='
                padding: 10px 20px;
                background-color: #ff9900;
                color: white;
                border: none;
                border-radius: 10px;
                font-size: 16px;
                box-shadow: 2px 2px 8px rgba(0,0,0,0.3);
                cursor: pointer;
            '>
                🎨 Make a Caricature
            </button>
        </a>
    </div>
    """)
    # Uploading an image automatically calls generate_caption
    image_input.upload(
        fn=generate_caption,
        inputs=image_input,
        outputs=caption_output
    )
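    # (Optional) gr.Image also fires a `change` event when the image is replaced
    # or cleared; wiring it the same way is an untested sketch, kept safe by the
    # None guard in generate_caption:
    # image_input.change(fn=generate_caption, inputs=image_input, outputs=caption_output)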

# 5. Launch the web app
if __name__ == "__main__":
    demo.launch(debug=True)

# --- Previous BLIP-based version, kept for reference ---
# import gradio as gr
# import torch
# from PIL import Image
# from transformers import BlipProcessor, BlipForConditionalGeneration
#
# # 1. Select the device
# device = "cuda" if torch.cuda.is_available() else "cpu"
#
# # 2. Load the model and processor
# processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
#
# # 3. Caption-generation function
# def generate_caption(image):
#     if image is None:
#         return "Please upload an image."
#     # Resize for faster processing
#     image = image.resize((384, 384))
#     # Generate the caption
#     inputs = processor(images=image, return_tensors="pt").to(device)
#     output_ids = model.generate(**inputs, max_length=50)
#     caption = processor.decode(output_ids[0], skip_special_tokens=True)
#     print("✅ Generated caption:", caption)
#     return caption
#
# # 4. Build the Gradio interface
# with gr.Blocks(title="Image Caption Generator") as demo:
#     gr.Markdown("## 🖼️ Upload an image and a caption is generated automatically.")
#     with gr.Row():
#         with gr.Column():
#             image_input = gr.Image(label="Input Image", type="pil")
#         with gr.Column():
#             caption_output = gr.Textbox(label="Generated Caption", lines=3, show_copy_button=True)
#     # Create the button with HTML
#     gr.HTML("""
#     <div style='margin-top: 10px; text-align: center;'>
#         <a href="https://huggingface.co/spaces/VIDraft/stable-diffusion-3.5-large-turboX" target="_blank">
#             <button style='padding: 10px 20px; background-color: #ff9900; color: white; border: none; border-radius: 10px; font-size: 16px; box-shadow: 2px 2px 8px rgba(0,0,0,0.3); cursor: pointer;'>
#                 🎨 Make a Caricature
#             </button>
#         </a>
#     </div>
#     """)
#     # Wire upload → automatic caption generation
#     image_input.upload(fn=generate_caption, inputs=image_input, outputs=caption_output)
#
# # 5. Launch the app
# demo.launch(debug=True)