Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import torch | |
| import os | |
| import shutil | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| from huggingface_hub import snapshot_download | |
| # --- ส่วนจัดการ Cache: ดึงโมเดล NSFW มาวางทับ Standard Model --- | |
| MODEL_STANDARD = "microsoft/Florence-2-base" | |
| MODEL_NSFW = "ljnlonoljpiljm/florence-2-base-nsfw-v2" | |
| def setup_model_cache(): | |
| """ ดาวน์โหลดโมเดล NSFW มาวางทับโฟลเดอร์ของโมเดลมาตรฐานใน Cache """ | |
| cache_dir = os.path.expanduser("~/.cache/huggingface/hub") | |
| folder_standard = f"models--{MODEL_STANDARD.replace('/', '--')}" | |
| folder_nsfw = f"models--{MODEL_NSFW.replace('/', '--')}" | |
| path_standard = os.path.join(cache_dir, folder_standard) | |
| path_nsfw = os.path.join(cache_dir, folder_nsfw) | |
| print(f"🔍 ตรวจสอบ Cache: {folder_standard}") | |
| if not os.path.exists(path_standard) or not os.listdir(path_standard): | |
| print(f"⚠️ ไม่พบโมเดลมาตรฐานใน Cache หรือโฟลเดอร์ว่าง") | |
| if os.path.exists(path_nsfw) and os.listdir(path_nsfw): | |
| print(f"✅ พบโมเดล NSFW ใน Cache แล้ว: {folder_nsfw}") | |
| source_path = path_nsfw | |
| else: | |
| print(f"🚀 กำลังดาวน์โหลดโมเดล NSFW ({MODEL_NSFW})...") | |
| try: | |
| snapshot_download( | |
| repo_id=MODEL_NSFW, | |
| local_dir=path_nsfw, | |
| local_dir_use_symlinks=False | |
| ) | |
| print("✅ ดาวน์โหลดโมเดล NSFW เสร็จสิ้น") | |
| source_path = path_nsfw | |
| except Exception as e: | |
| print(f"❌ ดาวน์โหลดล้มเหลว: {e}") | |
| print("💡 ใช้โมเดลมาตรฐานแทน (อาจไม่มี NSFW filter)") | |
| return MODEL_STANDARD | |
| if os.path.exists(path_standard): | |
| shutil.rmtree(path_standard) | |
| print(f"📂 กำลัง Copy ไฟล์จาก {folder_nsfw} -> {folder_standard}...") | |
| shutil.copytree(source_path, path_standard) | |
| print("✅ วางไฟล์ทับเสร็จสิ้น!") | |
| else: | |
| print(f"✅ พบโมเดลใน Cache แล้ว: {folder_standard}") | |
| return MODEL_STANDARD | |
| # --- เรียกฟังก์ชันจัดการ Cache --- | |
| FINAL_MODEL_NAME = setup_model_cache() | |
| print(f"🚀 กำลังโหลดโมเดล (ที่ถูกปรับแต่งแล้ว): {FINAL_MODEL_NAME}...") | |
| # --- ติดตั้ง flash-attn (ถ้าจำเป็น) --- | |
| try: | |
| import flash_attn | |
| print("✅ flash_attn พร้อมใช้งาน") | |
| except ImportError: | |
| print("⚠️ flash_attn ไม่พบ กำลังติดตั้ง...") | |
| subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) | |
| # --- โหลดโมเดล --- | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| florence_model = None | |
| florence_processor = None | |
| try: | |
| florence_model = AutoModelForCausalLM.from_pretrained( | |
| FINAL_MODEL_NAME, trust_remote_code=True | |
| ).to(device).eval() | |
| florence_processor = AutoProcessor.from_pretrained( | |
| FINAL_MODEL_NAME, trust_remote_code=True | |
| ) | |
| print("✅ โหลดโมเดล Florence-2 (NSFW Version) เสร็จสิ้น!") | |
| except Exception as e: | |
| print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}") | |
| florence_model = None | |
| florence_processor = None | |
| # --- รายการ Task ที่รองรับ (ตามที่คุณต้องการ) --- | |
| TASK_PROMPTS = { | |
| "Caption": "<CAPTION>", | |
| "Detailed Caption": "<DETAILED_CAPTION>", | |
| "More Detailed Caption": "<MORE_DETAILED_CAPTION>", | |
| "OCR": "<OCR>", | |
| "OCR with Region": "<OCR_WITH_REGION>", | |
| "Object Detection": "<OD>", | |
| "Dense Region Caption": "<DENSE_REGION_CAPTION>", | |
| "Region Proposal": "<REGION_PROPOSAL>", | |
| "Caption to Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>", | |
| "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>", | |
| "Region to Segmentation": "<REGION_TO_SEGMENTATION>", | |
| "Open Vocabulary Detection": "<OPEN_VOCABULARY_DETECTION>", | |
| "Region to Category": "<REGION_TO_CATEGORY>", | |
| "Region to Description": "<REGION_TO_DESCRIPTION>", | |
| } | |
| def process_image(image, task_name): | |
| global florence_model, florence_processor | |
| if florence_model is None or florence_processor is None: | |
| return "❌ โมเดลยังไม่ได้โหลดหรือเกิดข้อผิดพลาดในการเริ่มต้น" | |
| if image is None: | |
| return "กรุณาเลือกรูปภาพ" | |
| try: | |
| # แปลง Gradio Image เป็น PIL Image | |
| if not isinstance(image, Image.Image): | |
| image = Image.fromarray(image) | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # ดึง Prompt จาก Task ที่เลือก | |
| prompt = TASK_PROMPTS.get(task_name, "<CAPTION>") | |
| inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device) | |
| generated_ids = florence_model.generate( | |
| input_ids=inputs["input_ids"], | |
| pixel_values=inputs["pixel_values"], | |
| max_new_tokens=1024, | |
| early_stopping=False, | |
| do_sample=False, | |
| num_beams=3, | |
| ) | |
| generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0] | |
| parsed_answer = florence_processor.post_process_generation( | |
| generated_text, task=prompt, image_size=(image.width, image.height) | |
| ) | |
| # --- ส่วนแก้ไข: แยกค่าออกจาก Dictionary --- | |
| # parsed_answer จะเป็น dict เช่น {'<DETAILED_CAPTION>': 'ข้อความ...'} | |
| # เราต้องการแค่ 'ข้อความ...' | |
| if isinstance(parsed_answer, dict): | |
| # ดึงค่าแรก (Value) จาก Dictionary | |
| result_text = next(iter(parsed_answer.values())) | |
| else: | |
| # ถ้าไม่ใช่ dict (กรณีบาง task อาจ return string โดยตรง) | |
| result_text = str(parsed_answer) | |
| print(f"\n✅ Task: {task_name} | Result: {result_text}") | |
| return result_text | |
| except Exception as e: | |
| return f"❌ เกิดข้อผิดพลาดขณะประมวลผล: {str(e)}" | |
| # --- สร้าง UI --- | |
| with gr.Blocks(title="Image-to-Prompt (Florence-2)", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # 🖼️ Image-to-Prompt (Florence-2) | |
| อัปโหลดรูปภาพและเลือกประเภทการวิเคราะห์เพื่อสร้าง Prompt | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| image_input = gr.Image(type="pil", label="รูปภาพ", height=400) | |
| task_dropdown = gr.Dropdown( | |
| choices=list(TASK_PROMPTS.keys()), | |
| value="More Detailed Caption", # เลือกเป็นค่าเริ่มต้น | |
| label="ประเภทการวิเคราะห์" | |
| ) | |
| btn = gr.Button("🚀 สร้างผลลัพธ์", variant="primary") | |
| with gr.Column(): | |
| text_output = gr.Textbox(label="ผลลัพธ์", lines=10, max_lines=20, show_copy_button=True) | |
| btn.click( | |
| fn=process_image, | |
| inputs=[image_input, task_dropdown], | |
| outputs=text_output | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(debug=True) |