userIdc2024 commited on
Commit
7ff26ab
·
verified ·
1 Parent(s): 2876716

Create multiodel_image_processor.py

Browse files
generator_function/multiodel_image_processor.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, zipfile, tempfile, logging, base64
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ from typing import List, Tuple, Optional
4
+ from PIL import Image
5
+ import io
6
+
7
+ from generator_function.image_function import generate_image
8
+ from prompt.prompt_services import get_prompts
9
+ from multimodel_services.replicate_generation_service import generate_image_with_model, convert_size_to_aspect_ratio
10
+ from multimodel_services.model_manager import is_gpt_model, get_all_parameters
11
+ from helpers_function.helper_meta_data import meta_data_helper_function
12
+ from helpers_function.helpers import upload_image_to_r2
13
+ from helpers_function.helpers import is_valid_image
14
+ from database.connections import get_results_collection as get_collection
15
+ from database.operations import start_job, finish_job
16
+ from util.session_state import current_uid
17
+
18
+ logger = logging.getLogger(__name__)
19
+ COL = get_collection()
20
+
21
+ def _resolve_user_id() -> str:
22
+ return current_uid() or os.getenv("DEFAULT_USER_ID", "anonymous")
23
+
24
+ def process_zip_and_generate_images_multimodel(
25
+ zip_path: str,
26
+ category: str,
27
+ size: str,
28
+ quality: str,
29
+ user_prompt: str,
30
+ sentiment: str,
31
+ platform: str,
32
+ num_images: int,
33
+ demo_mode: bool,
34
+ existing_images: Optional[List[str]],
35
+ blur: bool,
36
+ uid: str,
37
+ selected_model: str = "gpt_default",
38
+ model_params: Optional[dict] = None,
39
+ ) -> List[str]:
40
+ """Enhanced image processor that supports both GPT and multimodel approaches"""
41
+ num_images = 1 if demo_mode else num_images
42
+ try:
43
+ if zip_path.endswith(".zip"):
44
+ temp_dir = extract_zip_file(zip_path)
45
+ image_files = get_valid_image_files(temp_dir)
46
+ else:
47
+ image_files = [(os.path.basename(zip_path), zip_path)]
48
+
49
+ results = process_image_files_multimodel(
50
+ image_files, category, size, quality, user_prompt, sentiment, platform,
51
+ num_images, blur, uid, selected_model, model_params
52
+ )
53
+ all_urls = [url for entry in results for url in entry["urls"]]
54
+ seen, deduped = set(), []
55
+ for u in all_urls:
56
+ if u not in seen:
57
+ seen.add(u); deduped.append(u)
58
+ return (existing_images or []) + deduped
59
+ except Exception:
60
+ logger.exception(f"Global error during processing file: {zip_path}")
61
+ return existing_images or []
62
+
63
+ def extract_zip_file(zip_path: str) -> tempfile.TemporaryDirectory:
64
+ temp_dir = tempfile.TemporaryDirectory()
65
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
66
+ zip_ref.extractall(temp_dir.name)
67
+ logger.info(f"Extracted ZIP file: {zip_path}")
68
+ return temp_dir
69
+
70
+ def get_valid_image_files(temp_dir: tempfile.TemporaryDirectory) -> List[Tuple[str, str]]:
71
+ valid_files: List[Tuple[str, str]] = []
72
+ for file in os.listdir(temp_dir.name):
73
+ if "__MACOSX" in file: continue
74
+ file_path = os.path.join(temp_dir.name, file)
75
+ if is_valid_image(file):
76
+ valid_files.append((file, file_path))
77
+ else:
78
+ logger.warning(f"Ignored non-image file: {file}")
79
+ logger.info(f"Found {len(valid_files)} valid images.")
80
+ return valid_files
81
+
82
+ def process_image_files_multimodel(image_files: List[Tuple[str, str]], category: str, size: str,
83
+ quality: str, user_prompt: str, sentiment: str, platform: str, num_images: int, blur: bool,
84
+ uid: str, selected_model: str, model_params: Optional[dict]) -> List[dict]:
85
+ """Process image files with multimodel support"""
86
+ final_results: List[dict] = []
87
+ with ThreadPoolExecutor(max_workers=5) as executor:
88
+ futures = []
89
+ for file_name, file_path in image_files:
90
+ job_id: Optional[str] = None
91
+ if COL is not None:
92
+ try:
93
+ settings = {
94
+ "size": size, "quality": quality, "sentiment": sentiment,
95
+ "platform": platform, "num_images": num_images, "blur": bool(blur),
96
+ "selected_model": selected_model, "model_params": model_params or {}
97
+ }
98
+ inputs = {"file_name": file_name, "mode": "img_or_zip_multimodel"}
99
+ job_id = start_job(
100
+ COL,
101
+ type="variation_multimodel",
102
+ created_by=uid,
103
+ category=category or "general",
104
+ inputs=inputs,
105
+ settings=settings,
106
+ user_prompt=user_prompt
107
+ )
108
+ except Exception:
109
+ logger.exception("Failed to start DB job; continuing without DB logging.")
110
+ futures.append(
111
+ executor.submit(
112
+ process_single_image_multimodel,
113
+ file_name, file_path, category, size, quality, user_prompt, sentiment,
114
+ platform, num_images, blur, job_id, selected_model, model_params,
115
+ )
116
+ )
117
+ for future in as_completed(futures):
118
+ try:
119
+ result = future.result()
120
+ if result: final_results.append(result)
121
+ except Exception:
122
+ logger.exception("Unhandled exception during image processing thread.")
123
+ return final_results
124
+
125
+ def process_single_image_multimodel(file_name: str, file_path: str, category: str, size: str,
126
+ quality: str, user_prompt: str, sentiment: str, platform: str, num_images: int, blur: bool,
127
+ job_id: Optional[str], selected_model: str, model_params: Optional[dict]) -> Optional[dict]:
128
+ """Process single image with multimodel support"""
129
+ try:
130
+ image_urls = generate_images_from_prompts_multimodel(
131
+ file_path, size, quality, category, sentiment, user_prompt, platform,
132
+ num_images, blur, selected_model, model_params
133
+ )
134
+ if COL is not None and job_id:
135
+ try:
136
+ finish_job(COL, job_id, status=("completed" if image_urls else "failed"), outputs_urls=image_urls)
137
+ except Exception:
138
+ logger.exception("Failed to finish DB job.")
139
+ if image_urls:
140
+ return {"file_name": file_name, "urls": image_urls}
141
+ return None
142
+ except Exception as e:
143
+ logger.error(f"Processing failed for {file_name}: {e}")
144
+ if COL is not None and job_id:
145
+ try:
146
+ finish_job(COL, job_id, status="failed", outputs_urls=[])
147
+ except Exception:
148
+ logger.exception("Also failed to mark DB job as failed.")
149
+ return None
150
+
151
+ def generate_images_from_prompts_multimodel(
152
+ file_path: str, size: str, quality: str, category: str, sentiment: str, user_prompt: str,
153
+ platform: str, num_images: int, blur: bool, selected_model: str, model_params: Optional[dict],
154
+ ) -> List[str]:
155
+ """Generate images using either GPT or multimodel approach"""
156
+ image_urls: List[str] = []
157
+
158
+ def worker(i: int) -> Optional[str]:
159
+ try:
160
+ if is_gpt_model(selected_model):
161
+ # Use existing GPT approach
162
+ image_bytes = generate_image(file_path, size, quality, category, sentiment, user_prompt, platform, blur, i)
163
+ else:
164
+ # Use multimodel approach
165
+ image_bytes = generate_image_multimodel(
166
+ file_path, selected_model, category, sentiment, user_prompt,
167
+ platform, size, model_params, i
168
+ )
169
+
170
+ if not image_bytes: return None
171
+ image_with_metadata = meta_data_helper_function(image_bytes)
172
+ s3_url = upload_image_to_r2(image_with_metadata)
173
+ return s3_url
174
+ except Exception as e:
175
+ logger.error(f"Image generation failed: {e}")
176
+ return None
177
+
178
+ with ThreadPoolExecutor(max_workers=min(10, num_images)) as executor:
179
+ futures = [executor.submit(worker, i) for i in range(num_images)]
180
+ for future in as_completed(futures):
181
+ result = future.result()
182
+ if result: image_urls.append(result)
183
+ return image_urls
184
+
185
+ def generate_image_multimodel(file_path: str, model_name: str, category: str, sentiment: str,
186
+ user_prompt: str, platform: str, size: str, model_params: Optional[dict],
187
+ variation_index: int) -> Optional[bytes]:
188
+ """Generate image using multimodel approach"""
189
+ try:
190
+ # Convert image to base64
191
+ with open(file_path, 'rb') as f:
192
+ image_data = f.read()
193
+ base64_image = base64.b64encode(image_data).decode()
194
+
195
+ # Use existing prompt service to get prompt variations
196
+ prompt_variations = get_prompts(
197
+ base64_image, category, user_prompt, sentiment, None
198
+ )
199
+
200
+ # Select a prompt based on variation index
201
+ if prompt_variations and len(prompt_variations) > 0:
202
+ selected_prompt = prompt_variations[variation_index % len(prompt_variations)]
203
+ else:
204
+ # Fallback prompt
205
+ selected_prompt = f"Generate a high-quality {category or 'advertising'} image. {user_prompt}"
206
+
207
+ # Prepare model parameters
208
+ all_params = get_all_parameters(model_name, model_params)
209
+
210
+ # Only convert size to aspect ratio if no aspect ratio was provided by user
211
+ if model_name == "google/nano-banana" and "aspect_ratio" in all_params:
212
+ # If aspect_ratio is the default value, convert from size
213
+ if all_params["aspect_ratio"] == "match_input_image": # This is the default
214
+ converted_ratio = convert_size_to_aspect_ratio(size, model_name)
215
+ all_params["aspect_ratio"] = converted_ratio
216
+ logger.info(f"Converted size '{size}' to aspect ratio '{converted_ratio}' for {model_name}")
217
+ else:
218
+ logger.info(f"Using user-selected aspect ratio '{all_params['aspect_ratio']}' for {model_name}")
219
+
220
+ logger.info(f"Final parameters for {model_name}: {all_params}")
221
+
222
+ # Generate image with selected model
223
+ generated_image_data = generate_image_with_model(
224
+ model_name, selected_prompt, all_params, base64_image
225
+ )
226
+
227
+ return generated_image_data
228
+
229
+ except Exception as e:
230
+ logger.error(f"Multimodel generation failed: {e}")
231
+ return None