lacos03 committed on
Commit
c1253b3
·
1 Parent(s): d0f68a4

Cho phép sinh nhiều ảnh

Browse files
Files changed (1) hide show
  1. app.py +127 -165
app.py CHANGED
@@ -2,189 +2,151 @@ import gradio as gr
2
  import torch
3
  from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
4
  from diffusers import StableDiffusionPipeline
5
- import io
6
  from PIL import Image
7
- import traceback
8
  import os
9
- from pathlib import Path
 
10
 
11
- # === Thiết lập môi trường ===
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"Device: {device}")
14
 
15
- # === Load models với xử lý lỗi ===
16
- try:
17
- # Summarizer (BART)
18
- model_name = "lacos03/bart-base-finetuned-xsum"
19
- print(f"Loading BART model from {model_name}...")
20
- tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
21
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
22
- model.to(device)
23
- summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
24
- print("✅ BART loaded successfully")
25
- except Exception as e:
26
- print(f"❌ Error loading BART: {e}")
27
- summarizer = None
28
-
29
- try:
30
- # Promptist
31
- print("Loading Promptist model...")
32
- def load_prompter():
33
- prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
34
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
35
- tokenizer.pad_token = tokenizer.eos_token
36
- tokenizer.padding_side = "left"
37
- return prompter_model, tokenizer
38
- promptist_model, promptist_tokenizer = load_prompter()
39
- print("✅ Promptist loaded successfully")
40
- except Exception as e:
41
- print(f"❌ Error loading Promptist: {e}")
42
- promptist_model = None
43
- promptist_tokenizer = None
44
-
45
- try:
46
- # Stable Diffusion + LoRA
47
- print("Loading Stable Diffusion model...")
48
- sd_model_id = "runwayml/stable-diffusion-v1-5"
49
- image_generator = StableDiffusionPipeline.from_pretrained(
50
- sd_model_id,
51
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
52
- use_safetensors=True
53
- ).to(device)
54
- lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
55
- print(f"Loading LoRA weights from {lora_weights}...")
56
- image_generator.load_lora_weights(lora_weights)
57
- print("✅ Stable Diffusion with LoRA loaded successfully")
58
- except Exception as e:
59
- print(f"❌ Error loading Stable Diffusion or LoRA: {e}")
60
- image_generator = None
61
-
62
- # === Modular hóa ===
63
- def summarize(article_text):
64
- if not summarizer or not article_text.strip():
65
- return "[Empty input or model not loaded]", "[Empty input or model not loaded]"
66
- try:
67
- summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
68
- title = summary.split(".")[0] + "."
69
- return title, summary
70
- except Exception as e:
71
- return f"[Error in summarization: {e}]", f"[Error in summarization: {e}]"
72
-
73
- def generate_prompt(title):
74
- if not promptist_model or not promptist_tokenizer or not title:
75
- return "[Error: Promptist not loaded or no title]"
76
- try:
77
- input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device)
78
- eos_id = promptist_tokenizer.eos_token_id
79
- outputs = promptist_model.generate(
80
- input_ids,
81
- do_sample=False,
82
- max_new_tokens=75,
83
- num_beams=8,
84
- num_return_sequences=8,
85
- eos_token_id=eos_id,
86
- pad_token_id=eos_id,
87
- length_penalty=-1.0
88
- )
89
- output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
90
- prompt = output_texts[0].replace(title + " Rephrase:", "").strip()
91
- return prompt
92
- except Exception as e:
93
- return f"[Error in prompt generation: {e}]"
94
-
95
- def generate_image(prompt, style):
96
- if not image_generator or not prompt:
97
- blank = Image.new("RGB", (512, 512), (255, 255, 255))
98
- img_byte_arr = io.BytesIO()
99
- blank.save(img_byte_arr, format="PNG")
100
- img_byte_arr.seek(0)
101
- return blank, img_byte_arr
102
- try:
103
- styled_prompt = f"{prompt}, {style.lower()} style"
104
- result = image_generator(
105
- styled_prompt,
106
- num_inference_steps=50,
107
- guidance_scale=7.5
108
- ).images[0]
109
- img_byte_arr = io.BytesIO()
110
- result.save(img_byte_arr, format="PNG")
111
- img_byte_arr.seek(0)
112
- return result, img_byte_arr
113
- except Exception as e:
114
- print(f"❌ Image generation error: {traceback.format_exc()}")
115
- blank = Image.new("RGB", (512, 512), (255, 255, 255))
116
- img_byte_arr = io.BytesIO()
117
- blank.save(img_byte_arr, format="PNG")
118
- img_byte_arr.seek(0)
119
- return blank, img_byte_arr
120
-
121
- # === Main processing function with staged outputs ===
122
- def process_step_by_step(article_text, style_choice, state=None):
123
- if state is None:
124
- state = {"title": None, "prompt": None, "image": None, "file_path": None}
125
-
126
- # Bước 1: Tóm tắt và tạo tiêu đề
127
- title, summary = summarize(article_text)
128
- print(f"Summary title: {title}")
129
- state["title"] = title
130
- yield state, title, None, None, None
131
-
132
- # Bước 2: Tạo prompt
133
- prompt = generate_prompt(title)
134
- print(f"Generated prompt: {prompt}")
135
- state["prompt"] = prompt
136
- yield state, title, prompt, None, None
137
-
138
- # Bước 3: Tạo ảnh
139
- image, img_bytes = generate_image(prompt, style_choice)
140
- print(f"Image generated: {image.size if image else 'None'}")
141
-
142
- # Lưu ảnh tạm thời
143
- temp_dir = "./temp"
144
  os.makedirs(temp_dir, exist_ok=True)
145
- temp_file = os.path.join(temp_dir, f"generated_image_{id(image)}.png")
146
- image.save(temp_file, format="PNG")
147
- state["image"] = image
148
- state["file_path"] = temp_file
149
-
150
- print(f"✅ Process completed")
151
- yield state, title, prompt, image, temp_file
152
-
153
- # === Gradio UI ===
 
154
  def create_app():
155
  with gr.Blocks() as demo:
156
- gr.Markdown("## 📰 Article → 🖼️ Image Generator")
157
- gr.Markdown("Nhập bài viết → Sinh tiêu đề → Tối ưu prompt → Sinh ảnh minh họa tự động")
158
-
159
- # State để lưu trữ trạng thái tạm thời
160
- state = gr.State(value=None)
161
 
 
162
  with gr.Row():
163
- article_input = gr.Textbox(label="📄 Bài viết", lines=10, placeholder="Dán nội dung bài viết ở đây...")
164
- style_dropdown = gr.Dropdown(choices=["Art", "Anime", "Watercolor", "Cyberpunk"], label="🎨 Phong cách ảnh", value="Art")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- with gr.Row():
167
- submit_button = gr.Button("🚀 Tạo Tiêu đề & Ảnh Minh họa")
 
 
168
 
169
- with gr.Row():
170
- title_output = gr.Textbox(label="📌 Tiêu đề được tạo")
171
- prompt_output = gr.Textbox(label="🔧 Prompt sinh ảnh")
172
 
173
- image_output = gr.Image(label="🖼️ Ảnh minh họa", interactive=True)
174
- download_button = gr.File(label="📥 Tải ảnh")
175
-
176
- feedback = gr.Radio(["👍 Hài lòng", "👎 Không hài lòng"], label="📊 Bạn có hài lòng với kết quả không?", value=None)
 
177
 
178
- # Gắn sự kiện nút submit với hàm process từng bước
179
- submit_button.click(
180
- fn=process_step_by_step,
181
- inputs=[article_input, style_dropdown, state],
182
- outputs=[state, title_output, prompt_output, image_output, download_button]
183
  )
184
 
185
  return demo
186
 
187
- # === Launch ===
188
  if __name__ == "__main__":
189
  app = create_app()
190
- app.launch(debug=True, share=True)
 
2
  import torch
3
  from transformers import pipeline, AutoModelForSeq2SeqLM, BartTokenizer, AutoModelForCausalLM, AutoTokenizer
4
  from diffusers import StableDiffusionPipeline
 
5
  from PIL import Image
6
+ import io
7
  import os
8
+ import zipfile
9
+ import traceback
10
 
11
+ # === Thiết lập thiết bị ===
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"Device: {device}")
14
 
15
+ # === Load models ===
16
+ # BART Summarizer
17
+ model_name = "lacos03/bart-base-finetuned-xsum"
18
+ tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)
19
+ model = AutoModelForSeq2SeqLM.from_pretrained(
20
+ model_name,
21
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
22
+ ).to(device)
23
+ summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1)
24
+
25
+ # Promptist
26
+ promptist_model = AutoModelForCausalLM.from_pretrained(
27
+ "microsoft/Promptist",
28
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
29
+ ).to(device)
30
+ promptist_tokenizer = AutoTokenizer.from_pretrained("gpt2")
31
+ promptist_tokenizer.pad_token = promptist_tokenizer.eos_token
32
+ promptist_tokenizer.padding_side = "left"
33
+
34
+ # Stable Diffusion + LoRA
35
+ sd_model_id = "runwayml/stable-diffusion-v1-5"
36
+ image_generator = StableDiffusionPipeline.from_pretrained(
37
+ sd_model_id,
38
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
39
+ use_safetensors=True
40
+ ).to(device)
41
+ lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
42
+ image_generator.load_lora_weights(lora_weights)
43
+
44
+ # === Hàm xử lý ===
45
+ def summarize_article(article_text):
46
+ """Tóm tắt bài viết và tạo prompt refinement"""
47
+ if not article_text.strip():
48
+ return "[Empty]", "[Empty]"
49
+ summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
50
+ title = summary.split(".")[0] + "."
51
+ # Prompt refinement
52
+ input_ids = promptist_tokenizer(title.strip() + " Rephrase:", return_tensors="pt").input_ids.to(device)
53
+ eos_id = promptist_tokenizer.eos_token_id
54
+ outputs = promptist_model.generate(
55
+ input_ids,
56
+ do_sample=False,
57
+ max_new_tokens=75,
58
+ num_beams=8,
59
+ num_return_sequences=1,
60
+ eos_token_id=eos_id,
61
+ pad_token_id=eos_id,
62
+ length_penalty=-1.0
63
+ )
64
+ output_texts = promptist_tokenizer.batch_decode(outputs, skip_special_tokens=True)
65
+ prompt = output_texts[0].replace(title + " Rephrase:", "").strip()
66
+ return title, prompt
67
+
68
+ def generate_images(prompt, style, num_images=4):
69
+ """Sinh nhiều ảnh"""
70
+ styled_prompt = f"{prompt}, {style.lower()} style"
71
+ results = image_generator(
72
+ styled_prompt,
73
+ num_inference_steps=50,
74
+ guidance_scale=7.5,
75
+ num_images_per_prompt=num_images
76
+ ).images
77
+ return results
78
+
79
+ def save_selected_images(selected_idx, all_images):
80
+ """Lưu ảnh đã chọn và nén thành ZIP"""
81
+ if not selected_idx:
82
+ return None
83
+ temp_dir = "./temp_selected"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  os.makedirs(temp_dir, exist_ok=True)
85
+ zip_path = os.path.join(temp_dir, "selected_images.zip")
86
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
87
+ for idx in selected_idx:
88
+ img = all_images[int(idx)]
89
+ img_path = os.path.join(temp_dir, f"image_{idx}.png")
90
+ img.save(img_path, format="PNG")
91
+ zipf.write(img_path, f"image_{idx}.png")
92
+ return zip_path
93
+
94
+ # === UI Gradio ===
95
  def create_app():
96
  with gr.Blocks() as demo:
97
+ gr.Markdown("## 📰 Article → 🖼️ Multiple Image Generator with Selection")
 
 
 
 
98
 
99
+ # Bước 1: Nhập bài viết và sinh tiêu đề + prompt
100
  with gr.Row():
101
+ article_input = gr.Textbox(label="📄 Bài viết", lines=10)
102
+ style_dropdown = gr.Dropdown(
103
+ choices=["Art", "Anime", "Watercolor", "Cyberpunk"],
104
+ label="🎨 Phong cách ảnh", value="Art"
105
+ )
106
+ num_images_slider = gr.Slider(1, 8, value=4, step=1, label="🔢 Số lượng ảnh")
107
+ btn_summary = gr.Button("📌 Sinh Tiêu đề & Prompt")
108
+
109
+ title_output = gr.Textbox(label="Tiêu đề")
110
+ prompt_output = gr.Textbox(label="Prompt sinh ảnh")
111
+
112
+ # Bước 2: Sinh ảnh từ prompt đã refine
113
+ btn_generate_images = gr.Button("🎨 Sinh ảnh từ Prompt")
114
+ gallery = gr.Gallery(label="🖼️ Ảnh minh họa", columns=2, height=600)
115
+ selected_indices = gr.CheckboxGroup(choices=[], label="Chọn ảnh để tải về")
116
+
117
+ # Bước 3: Tải ảnh đã chọn
118
+ btn_download = gr.Button("📥 Tải ảnh đã chọn")
119
+ download_file = gr.File(label="File ZIP tải về")
120
+
121
+ # Logic
122
+ btn_summary.click(
123
+ fn=summarize_article,
124
+ inputs=[article_input],
125
+ outputs=[title_output, prompt_output]
126
+ )
127
 
128
+ def update_gallery(prompt, style, num_images):
129
+ images = generate_images(prompt, style, num_images)
130
+ choices = [str(i) for i in range(len(images))]
131
+ return images, gr.update(choices=choices, value=[]), images # images lưu tạm trong state
132
 
133
+ image_state = gr.State([])
 
 
134
 
135
+ btn_generate_images.click(
136
+ fn=update_gallery,
137
+ inputs=[prompt_output, style_dropdown, num_images_slider],
138
+ outputs=[gallery, selected_indices, image_state]
139
+ )
140
 
141
+ btn_download.click(
142
+ fn=save_selected_images,
143
+ inputs=[selected_indices, image_state],
144
+ outputs=[download_file]
 
145
  )
146
 
147
  return demo
148
 
149
+ # === Chạy app ===
150
  if __name__ == "__main__":
151
  app = create_app()
152
+ app.launch(debug=True, share=True)