R1000 commited on
Commit
d4bf08f
·
verified ·
1 Parent(s): 173507c

Rename app.py.bak to app.py

Browse files
Files changed (2) hide show
  1. app.py +148 -0
  2. app.py.bak +0 -69
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import torch
4
+ import os
5
+ import shutil
6
+ from PIL import Image
7
+ from transformers import AutoProcessor, AutoModelForCausalLM
8
+ from huggingface_hub import snapshot_download
9
+
10
+ # --- ส่วนจัดการ Cache: ดึงโมเดล NSFW มาวางทับ Standard Model ---
11
+ MODEL_STANDARD = "microsoft/Florence-2-base"
12
+ MODEL_NSFW = "ljnlonoljpiljm/florence-2-base-nsfw-v2"
13
+
14
+ def setup_model_cache():
15
+ """
16
+ ดาวน์โหลดโมเดล NSFW มาวางทับโฟลเดอร์ของโมเดลมาตรฐานใน Cache
17
+ เพื่อให้ from_pretrained(MODEL_STANDARD) โหลดไฟล์ของ NSFW แทน
18
+ """
19
+ cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
20
+
21
+ # ชื่อโฟลเดอร์ใน cache (แปลง / เป็น --)
22
+ folder_standard = f"models--{MODEL_STANDARD.replace('/', '--')}"
23
+ folder_nsfw = f"models--{MODEL_NSFW.replace('/', '--')}"
24
+
25
+ path_standard = os.path.join(cache_dir, folder_standard)
26
+ path_nsfw = os.path.join(cache_dir, folder_nsfw)
27
+
28
+ print(f"🔍 ตรวจสอบ Cache: {folder_standard}")
29
+
30
+ # ถ้าโฟลเดอร์มาตรฐานยังไม่มี หรือต้องการบังคับอัปเดต
31
+ if not os.path.exists(path_standard) or not os.listdir(path_standard):
32
+ print(f"⚠️ ไม่พบโมเดลมาตรฐานใน Cache หรือโฟลเดอร์ว่าง")
33
+
34
+ # ถ้ามี NSFW แล้ว ก็ไม่ต้องโหลดใหม่
35
+ if os.path.exists(path_nsfw) and os.listdir(path_nsfw):
36
+ print(f"✅ พบโมเดล NSFW ใน Cache แล้ว: {folder_nsfw}")
37
+ source_path = path_nsfw
38
+ else:
39
+ print(f"🚀 กำลังดาวน์โหลดโมเดล NSFW ({MODEL_NSFW})...")
40
+ try:
41
+ # ดาวน์โหลดลงโฟลเดอร์ชั่วคราว (ใช้ชื่อ NSFW)
42
+ snapshot_download(
43
+ repo_id=MODEL_NSFW,
44
+ local_dir=path_nsfw,
45
+ local_dir_use_symlinks=False
46
+ )
47
+ print("✅ ดาวน์โหลดโมเดล NSFW เสร็จสิ้น")
48
+ source_path = path_nsfw
49
+ except Exception as e:
50
+ print(f"❌ ดาวน์โหลดล้มเหลว: {e}")
51
+ print("💡 ใช้โมเดลมาตรฐานแทน (อาจไม่มี NSFW filter)")
52
+ return MODEL_STANDARD # คืนชื่อมาตรฐานให้โหลดปกติ
53
+
54
+ # ลบโฟลเดอร์มาตรฐานเดิม (ถ้ามี)
55
+ if os.path.exists(path_standard):
56
+ print(f"🗑️ ลบโฟลเดอร์มาตรฐานเดิม: {folder_standard}")
57
+ shutil.rmtree(path_standard)
58
+
59
+ # Copy ไฟล์ทั้งหมดจาก NSFW มาทับ Standard
60
+ print(f"📂 กำลัง Copy ไฟล์จาก {folder_nsfw} -> {folder_standard}...")
61
+ shutil.copytree(source_path, path_standard)
62
+ print("✅ วางไฟล์ทับเสร็จสิ้น! ตอนนี้ from_pretrained('microsoft/Florence-2-base') จะใช้ไฟล์ของ NSFW")
63
+
64
+ else:
65
+ print(f"✅ พบโมเดลใน Cache แล้ว: {folder_standard}")
66
+ # ตรวจสอบว่าไฟล์ภายในเป็นของ NSFW หรือไม่ (โดยดูไฟล์ config.json ถ้ามี)
67
+ # หากต้องการบังคับอัปเดตทุกครั้ง สามารถลบโฟลเดอร์นี้ทิ้งได้
68
+
69
+ return MODEL_STANDARD
70
+
71
+ # --- เรียกฟังก์ชันจัดการ Cache ก่อนโหลดโมเดล ---
72
+ FINAL_MODEL_NAME = setup_model_cache()
73
+ print(f"🚀 กำลังโหลดโมเดล (ที่ถูกปรับแต่งแล้ว): {FINAL_MODEL_NAME}...")
74
+
75
+ # ติดตั้ง flash-attn (ถ้าจำเป็น)
76
+ try:
77
+ import flash_attn
78
+ print("✅ flash_attn พร้อมใช้งาน")
79
+ except ImportError:
80
+ print("⚠️ flash_attn ไม่พบ กำลังติดตั้ง...")
81
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
82
+
83
+ # --- โหลดโมเดล ---
84
+ device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ florence_model = None
86
+ florence_processor = None
87
+
88
+ try:
89
+ # โหลดโมเดล (ตอนนี้จะโหลดไฟล์ที่เราก๊อ��ปี้มาแทนที่)
90
+ florence_model = AutoModelForCausalLM.from_pretrained(
91
+ FINAL_MODEL_NAME,
92
+ trust_remote_code=True
93
+ ).to(device).eval()
94
+
95
+ florence_processor = AutoProcessor.from_pretrained(
96
+ FINAL_MODEL_NAME,
97
+ trust_remote_code=True
98
+ )
99
+ print("✅ โหลดโมเดล Florence-2 (NSFW Version) เสร็จสิ้น!")
100
+ except Exception as e:
101
+ print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}")
102
+ print("💡 ตรวจสอบ Logs เพื่อดูรายละเอียด")
103
+ florence_model = None
104
+ florence_processor = None
105
+
106
+ def generate_caption(image):
107
+ global florence_model, florence_processor
108
+
109
+ if florence_model is None or florence_processor is None:
110
+ return "❌ โมเดลยังไม่ได้โหลดหรือเกิดข้อผิดพลาดในการเริ่มต้น"
111
+
112
+ if not isinstance(image, Image.Image):
113
+ image = Image.fromarray(image)
114
+
115
+ try:
116
+ inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
117
+ generated_ids = florence_model.generate(
118
+ input_ids=inputs["input_ids"],
119
+ pixel_values=inputs["pixel_values"],
120
+ max_new_tokens=1024,
121
+ early_stopping=False,
122
+ do_sample=False,
123
+ num_beams=3,
124
+ )
125
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
126
+ parsed_answer = florence_processor.post_process_generation(
127
+ generated_text,
128
+ task="<MORE_DETAILED_CAPTION>",
129
+ image_size=(image.width, image.height)
130
+ )
131
+ prompt = parsed_answer["<MORE_DETAILED_CAPTION>"]
132
+ print("\n\nGeneration completed!:" + prompt)
133
+ return prompt
134
+ except Exception as e:
135
+ return f"❌ เกิดข้อผิดพลาดขณะประมวลผล: {str(e)}"
136
+
137
+ # --- สร้าง UI ---
138
+ io = gr.Interface(
139
+ fn=generate_caption,
140
+ inputs=[gr.Image(label="Input Image")],
141
+ outputs=[gr.Textbox(label="Output Prompt", lines=2, show_copy_button=True)],
142
+ deep_link=False,
143
+ title="Image-to-Prompt (Florence-2 NSFW)",
144
+ description="อัปโหลดรูปภาพเพื่อสร้าง Prompt (ใช้โมเดล NSFW ที่ปรับแต่งแล้ว)"
145
+ )
146
+
147
+ if __name__ == "__main__":
148
+ io.launch(debug=True)
app.py.bak DELETED
@@ -1,69 +0,0 @@
1
- import gradio as gr
2
- import subprocess
3
- import torch
4
- from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForCausalLM
6
-
7
- # import os
8
- # import random
9
- # from gradio_client import Client
10
-
11
-
12
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13
-
14
- # Initialize Florence model
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
17
- florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
18
-
19
- # api_key = os.getenv("HF_READ_TOKEN")
20
-
21
- def generate_caption(image):
22
- if not isinstance(image, Image.Image):
23
- image = Image.fromarray(image)
24
-
25
- inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
26
- generated_ids = florence_model.generate(
27
- input_ids=inputs["input_ids"],
28
- pixel_values=inputs["pixel_values"],
29
- max_new_tokens=1024,
30
- early_stopping=False,
31
- do_sample=False,
32
- num_beams=3,
33
- )
34
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
35
- parsed_answer = florence_processor.post_process_generation(
36
- generated_text,
37
- task="<MORE_DETAILED_CAPTION>",
38
- image_size=(image.width, image.height)
39
- )
40
- prompt = parsed_answer["<MORE_DETAILED_CAPTION>"]
41
- print("\n\nGeneration completed!:"+ prompt)
42
- return prompt
43
- # yield prompt, None
44
- # image_path = generate_image(prompt,random.randint(0, 4294967296))
45
- # yield prompt, image_path
46
-
47
- # def generate_image(prompt, seed=42, width=1024, height=1024):
48
- # try:
49
- # result = Client("KingNish/Realtime-FLUX", hf_token=api_key).predict(
50
- # prompt=prompt,
51
- # seed=seed,
52
- # width=width,
53
- # height=height,
54
- # api_name="/generate_image"
55
- # )
56
- # # Extract the image path from the result tuple
57
- # image_path = result[0]
58
- # return image_path
59
- # except Exception as e:
60
- # raise Exception(f"Error generating image: {str(e)}")
61
-
62
- io = gr.Interface(generate_caption,
63
- inputs=[gr.Image(label="Input Image")],
64
- outputs = [gr.Textbox(label="Output Prompt", lines=2, show_copy_button = True),
65
- # gr.Image(label="Output Image")
66
- ],
67
- deep_link=False
68
- )
69
- io.launch(debug=True)