Spaces:
Paused
Paused
File size: 8,207 Bytes
d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f 7ae2e3b d4bf08f e78286a 7ae2e3b e78286a 7ae2e3b d4bf08f e78286a 7ae2e3b e78286a 7ae2e3b d4bf08f 7ae2e3b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | import gradio as gr
import subprocess
import torch
import os
import shutil
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import snapshot_download
# --- ส่วนจัดการ Cache: ดึงโมเดล NSFW มาวางทับ Standard Model ---
MODEL_STANDARD = "microsoft/Florence-2-base"
MODEL_NSFW = "ljnlonoljpiljm/florence-2-base-nsfw-v2"
def setup_model_cache():
""" ดาวน์โหลดโมเดล NSFW มาวางทับโฟลเดอร์ของโมเดลมาตรฐานใน Cache """
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
folder_standard = f"models--{MODEL_STANDARD.replace('/', '--')}"
folder_nsfw = f"models--{MODEL_NSFW.replace('/', '--')}"
path_standard = os.path.join(cache_dir, folder_standard)
path_nsfw = os.path.join(cache_dir, folder_nsfw)
print(f"🔍 ตรวจสอบ Cache: {folder_standard}")
if not os.path.exists(path_standard) or not os.listdir(path_standard):
print(f"⚠️ ไม่พบโมเดลมาตรฐานใน Cache หรือโฟลเดอร์ว่าง")
if os.path.exists(path_nsfw) and os.listdir(path_nsfw):
print(f"✅ พบโมเดล NSFW ใน Cache แล้ว: {folder_nsfw}")
source_path = path_nsfw
else:
print(f"🚀 กำลังดาวน์โหลดโมเดล NSFW ({MODEL_NSFW})...")
try:
snapshot_download(
repo_id=MODEL_NSFW,
local_dir=path_nsfw,
local_dir_use_symlinks=False
)
print("✅ ดาวน์โหลดโมเดล NSFW เสร็จสิ้น")
source_path = path_nsfw
except Exception as e:
print(f"❌ ดาวน์โหลดล้มเหลว: {e}")
print("💡 ใช้โมเดลมาตรฐานแทน (อาจไม่มี NSFW filter)")
return MODEL_STANDARD
if os.path.exists(path_standard):
shutil.rmtree(path_standard)
print(f"📂 กำลัง Copy ไฟล์จาก {folder_nsfw} -> {folder_standard}...")
shutil.copytree(source_path, path_standard)
print("✅ วางไฟล์ทับเสร็จสิ้น!")
else:
print(f"✅ พบโมเดลใน Cache แล้ว: {folder_standard}")
return MODEL_STANDARD
# --- เรียกฟังก์ชันจัดการ Cache ---
FINAL_MODEL_NAME = setup_model_cache()
print(f"🚀 กำลังโหลดโมเดล (ที่ถูกปรับแต่งแล้ว): {FINAL_MODEL_NAME}...")
# --- ติดตั้ง flash-attn (ถ้าจำเป็น) ---
try:
import flash_attn
print("✅ flash_attn พร้อมใช้งาน")
except ImportError:
print("⚠️ flash_attn ไม่พบ กำลังติดตั้ง...")
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# --- โหลดโมเดล ---
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = None
florence_processor = None
try:
florence_model = AutoModelForCausalLM.from_pretrained(
FINAL_MODEL_NAME, trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
FINAL_MODEL_NAME, trust_remote_code=True
)
print("✅ โหลดโมเดล Florence-2 (NSFW Version) เสร็จสิ้น!")
except Exception as e:
print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}")
florence_model = None
florence_processor = None
# --- รายการ Task ที่รองรับ (ตามที่คุณต้องการ) ---
TASK_PROMPTS = {
"Caption": "<CAPTION>",
"Detailed Caption": "<DETAILED_CAPTION>",
"More Detailed Caption": "<MORE_DETAILED_CAPTION>",
"OCR": "<OCR>",
"OCR with Region": "<OCR_WITH_REGION>",
"Object Detection": "<OD>",
"Dense Region Caption": "<DENSE_REGION_CAPTION>",
"Region Proposal": "<REGION_PROPOSAL>",
"Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
"Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
"Region to Segmentation": "<REGION_TO_SEGMENTATION>",
"Open Vocabulary Detection": "<OPEN_VOCABULARY_DETECTION>",
"Region to Category": "<REGION_TO_CATEGORY>",
"Region to Description": "<REGION_TO_DESCRIPTION>",
}
def process_image(image, task_name):
global florence_model, florence_processor
if florence_model is None or florence_processor is None:
return "❌ โมเดลยังไม่ได้โหลดหรือเกิดข้อผิดพลาดในการเริ่มต้น"
if image is None:
return "กรุณาเลือกรูปภาพ"
try:
# แปลง Gradio Image เป็น PIL Image
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
if image.mode != "RGB":
image = image.convert("RGB")
# ดึง Prompt จาก Task ที่เลือก
prompt = TASK_PROMPTS.get(task_name, "<CAPTION>")
inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = florence_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = florence_processor.post_process_generation(
generated_text, task=prompt, image_size=(image.width, image.height)
)
# --- ส่วนแก้ไข: แยกค่าออกจาก Dictionary ---
# parsed_answer จะเป็น dict เช่น {'<DETAILED_CAPTION>': 'ข้อความ...'}
# เราต้องการแค่ 'ข้อความ...'
if isinstance(parsed_answer, dict):
# ดึงค่าแรก (Value) จาก Dictionary
result_text = next(iter(parsed_answer.values()))
else:
# ถ้าไม่ใช่ dict (กรณีบาง task อาจ return string โดยตรง)
result_text = str(parsed_answer)
print(f"\n✅ Task: {task_name} | Result: {result_text}")
return result_text
except Exception as e:
return f"❌ เกิดข้อผิดพลาดขณะประมวลผล: {str(e)}"
# --- สร้าง UI ---
with gr.Blocks(title="Image-to-Prompt (Florence-2)", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🖼️ Image-to-Prompt (Florence-2)
อัปโหลดรูปภาพและเลือกประเภทการวิเคราะห์เพื่อสร้าง Prompt
""")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="รูปภาพ", height=400)
task_dropdown = gr.Dropdown(
choices=list(TASK_PROMPTS.keys()),
value="More Detailed Caption", # เลือกเป็นค่าเริ่มต้น
label="ประเภทการวิเคราะห์"
)
btn = gr.Button("🚀 สร้างผลลัพธ์", variant="primary")
with gr.Column():
text_output = gr.Textbox(label="ผลลัพธ์", lines=10, max_lines=20, show_copy_button=True)
btn.click(
fn=process_image,
inputs=[image_input, task_dropdown],
outputs=text_output
)
if __name__ == "__main__":
demo.launch(debug=True) |