Image-to-Prompt / app.py
R1000's picture
Update app.py
e78286a verified
import gradio as gr
import subprocess
import torch
import os
import shutil
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import snapshot_download
# --- ส่วนจัดการ Cache: ดึงโมเดล NSFW มาวางทับ Standard Model ---
MODEL_STANDARD = "microsoft/Florence-2-base"
MODEL_NSFW = "ljnlonoljpiljm/florence-2-base-nsfw-v2"
def setup_model_cache():
""" ดาวน์โหลดโมเดล NSFW มาวางทับโฟลเดอร์ของโมเดลมาตรฐานใน Cache """
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
folder_standard = f"models--{MODEL_STANDARD.replace('/', '--')}"
folder_nsfw = f"models--{MODEL_NSFW.replace('/', '--')}"
path_standard = os.path.join(cache_dir, folder_standard)
path_nsfw = os.path.join(cache_dir, folder_nsfw)
print(f"🔍 ตรวจสอบ Cache: {folder_standard}")
if not os.path.exists(path_standard) or not os.listdir(path_standard):
print(f"⚠️ ไม่พบโมเดลมาตรฐานใน Cache หรือโฟลเดอร์ว่าง")
if os.path.exists(path_nsfw) and os.listdir(path_nsfw):
print(f"✅ พบโมเดล NSFW ใน Cache แล้ว: {folder_nsfw}")
source_path = path_nsfw
else:
print(f"🚀 กำลังดาวน์โหลดโมเดล NSFW ({MODEL_NSFW})...")
try:
snapshot_download(
repo_id=MODEL_NSFW,
local_dir=path_nsfw,
local_dir_use_symlinks=False
)
print("✅ ดาวน์โหลดโมเดล NSFW เสร็จสิ้น")
source_path = path_nsfw
except Exception as e:
print(f"❌ ดาวน์โหลดล้มเหลว: {e}")
print("💡 ใช้โมเดลมาตรฐานแทน (อาจไม่มี NSFW filter)")
return MODEL_STANDARD
if os.path.exists(path_standard):
shutil.rmtree(path_standard)
print(f"📂 กำลัง Copy ไฟล์จาก {folder_nsfw} -> {folder_standard}...")
shutil.copytree(source_path, path_standard)
print("✅ วางไฟล์ทับเสร็จสิ้น!")
else:
print(f"✅ พบโมเดลใน Cache แล้ว: {folder_standard}")
return MODEL_STANDARD
# --- เรียกฟังก์ชันจัดการ Cache ---
FINAL_MODEL_NAME = setup_model_cache()
print(f"🚀 กำลังโหลดโมเดล (ที่ถูกปรับแต่งแล้ว): {FINAL_MODEL_NAME}...")
# --- ติดตั้ง flash-attn (ถ้าจำเป็น) ---
try:
import flash_attn
print("✅ flash_attn พร้อมใช้งาน")
except ImportError:
print("⚠️ flash_attn ไม่พบ กำลังติดตั้ง...")
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# --- โหลดโมเดล ---
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = None
florence_processor = None
try:
florence_model = AutoModelForCausalLM.from_pretrained(
FINAL_MODEL_NAME, trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
FINAL_MODEL_NAME, trust_remote_code=True
)
print("✅ โหลดโมเดล Florence-2 (NSFW Version) เสร็จสิ้น!")
except Exception as e:
print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}")
florence_model = None
florence_processor = None
# --- รายการ Task ที่รองรับ (ตามที่คุณต้องการ) ---
TASK_PROMPTS = {
"Caption": "<CAPTION>",
"Detailed Caption": "<DETAILED_CAPTION>",
"More Detailed Caption": "<MORE_DETAILED_CAPTION>",
"OCR": "<OCR>",
"OCR with Region": "<OCR_WITH_REGION>",
"Object Detection": "<OD>",
"Dense Region Caption": "<DENSE_REGION_CAPTION>",
"Region Proposal": "<REGION_PROPOSAL>",
"Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
"Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
"Region to Segmentation": "<REGION_TO_SEGMENTATION>",
"Open Vocabulary Detection": "<OPEN_VOCABULARY_DETECTION>",
"Region to Category": "<REGION_TO_CATEGORY>",
"Region to Description": "<REGION_TO_DESCRIPTION>",
}
def process_image(image, task_name):
global florence_model, florence_processor
if florence_model is None or florence_processor is None:
return "❌ โมเดลยังไม่ได้โหลดหรือเกิดข้อผิดพลาดในการเริ่มต้น"
if image is None:
return "กรุณาเลือกรูปภาพ"
try:
# แปลง Gradio Image เป็น PIL Image
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
if image.mode != "RGB":
image = image.convert("RGB")
# ดึง Prompt จาก Task ที่เลือก
prompt = TASK_PROMPTS.get(task_name, "<CAPTION>")
inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = florence_processor.post_process_generation(
generated_text, task=prompt, image_size=(image.width, image.height)
)
# --- ส่วนแก้ไข: แยกค่าออกจาก Dictionary ---
# parsed_answer จะเป็น dict เช่น {'<DETAILED_CAPTION>': 'ข้อความ...'}
# เราต้องการแค่ 'ข้อความ...'
if isinstance(parsed_answer, dict):
# ดึงค่าแรก (Value) จาก Dictionary
result_text = next(iter(parsed_answer.values()))
else:
# ถ้าไม่ใช่ dict (กรณีบาง task อาจ return string โดยตรง)
result_text = str(parsed_answer)
print(f"\n✅ Task: {task_name} | Result: {result_text}")
return result_text
except Exception as e:
return f"❌ เกิดข้อผิดพลาดขณะประมวลผล: {str(e)}"
# --- สร้าง UI ---
with gr.Blocks(title="Image-to-Prompt (Florence-2)", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🖼️ Image-to-Prompt (Florence-2)
อัปโหลดรูปภาพและเลือกประเภทการวิเคราะห์เพื่อสร้าง Prompt
""")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="รูปภาพ", height=400)
task_dropdown = gr.Dropdown(
choices=list(TASK_PROMPTS.keys()),
value="More Detailed Caption", # เลือกเป็นค่าเริ่มต้น
label="ประเภทการวิเคราะห์"
)
btn = gr.Button("🚀 สร้างผลลัพธ์", variant="primary")
with gr.Column():
text_output = gr.Textbox(label="ผลลัพธ์", lines=10, max_lines=20, show_copy_button=True)
btn.click(
fn=process_image,
inputs=[image_input, task_dropdown],
outputs=text_output
)
if __name__ == "__main__":
demo.launch(debug=True)