um41r committed
Commit 700b060 · verified · Parent: fcaff5d

Update app.py

Files changed (1): app.py (+92, -133)
app.py CHANGED
@@ -3,41 +3,35 @@ import sys
 import gradio as gr
 import numpy as np
 from PIL import Image
-import onnxruntime as ort
-from huggingface_hub import hf_hub_download
-import torch
-from typing import Dict, List, Optional
-import json
 import warnings
 import logging
+import json
 
-# Suppress warnings
+# Suppress warnings immediately
 warnings.filterwarnings("ignore")
 logging.getLogger("transformers").setLevel(logging.ERROR)
 
 # Configuration
 MODEL_REPO = "onnx-community/Florence-2-base"
-
-# Use non-merged decoder to avoid the subgraph output issue
-# decoder_model_merged has a bug with outer scope values
 ONNX_FILES = {
     "vision_encoder": "vision_encoder_fp16.onnx",
     "embed_tokens": "embed_tokens_fp16.onnx",
     "encoder_model": "encoder_model_fp16.onnx",
-    "decoder_model": "decoder_model_fp16.onnx",  # Changed from merged to standard
-    "decoder_with_past_model": "decoder_with_past_model_fp16.onnx"  # For efficient generation
+    "decoder_model": "decoder_model_fp16.onnx",
+    "decoder_with_past_model": "decoder_with_past_model_fp16.onnx"
 }
 
-# Global variables for models
+# Global variables - will be initialized lazily
 sessions = {}
 processor = None
 tokenizer = None
+ort = None  # Will import lazily
 
 def download_models():
     """Download ONNX models from HuggingFace Hub"""
-    print("📥 Downloading ONNX models (FP16)...")
-    model_paths = {}
+    from huggingface_hub import hf_hub_download
 
+    print("📥 Downloading ONNX models (FP16)...")
     os.makedirs("./models/onnx", exist_ok=True)
 
     for name, filename in ONNX_FILES.items():
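The commit's main structural change is already visible in this first hunk: every heavy dependency (`onnxruntime`, `transformers`, `huggingface_hub`) moves from module scope into the function that uses it, so the Space's UI can start before any ML tooling is imported. A minimal sketch of the pattern; `probe_runtime` is an illustrative name, not part of the commit:

def probe_runtime():
    # Imported on first call only; Python then caches the module in
    # sys.modules, so repeat calls pay no import cost.
    import onnxruntime as ort
    return ort.get_available_providers()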
@@ -48,34 +42,35 @@ def download_models():
             local_dir="./models",
             local_dir_use_symlinks=False
         )
-        model_paths[name] = path
         size_mb = os.path.getsize(path) / (1024 * 1024)
         print(f" ✓ {name}: {size_mb:.1f}MB")
 
-    return model_paths
+    print("✅ All models downloaded!")
 
 def init_models():
     """Initialize ONNX Runtime sessions"""
-    global sessions, processor, tokenizer
+    global sessions, processor, tokenizer, ort
 
-    # Download models if not present
+    # Lazy import onnxruntime to avoid build issues
+    import onnxruntime as ort_module
+    ort = ort_module
+
+    # Lazy import transformers
+    from transformers import AutoProcessor, AutoTokenizer, BartTokenizerFast
+    import transformers
+
+    print(f"Transformers version: {transformers.__version__}")
+    print(f"ONNX Runtime version: {ort.__version__}")
+
+    # Check if models exist
     if not all(os.path.exists(f"./models/onnx/{f}") for f in ONNX_FILES.values()):
         download_models()
     else:
         print("✅ Models already cached")
 
-    # FORCE CPU-ONLY execution
+    # CPU-only providers
     providers = ['CPUExecutionProvider']
-    print(f"Using providers: {providers} (CPU-only mode)")
-
-    # Import transformers
-    try:
-        from transformers import AutoProcessor, AutoTokenizer, BartTokenizerFast
-        import transformers
-        print(f"Transformers version: {transformers.__version__}")
-    except ImportError as e:
-        print(f"Error importing transformers: {e}")
-        raise
+    print(f"Using providers: {providers}")
 
     # Load processor and tokenizer
     print("📥 Loading processor and tokenizer...")
@@ -87,7 +82,7 @@ def init_models():
             use_fast=False
         )
     except Exception as e:
-        print(f"Error loading processor: {e}")
+        print(f"Error with use_fast=False: {e}")
         processor = AutoProcessor.from_pretrained(
             "microsoft/Florence-2-base",
             trust_remote_code=True
@@ -110,8 +105,6 @@ def init_models():
 
         sess_options = ort.SessionOptions()
         sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
-        sess_options.enable_cpu_mem_arena = True
-        sess_options.enable_mem_pattern = True
 
         try:
             sessions[name] = ort.InferenceSession(
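Removing `enable_cpu_mem_arena` and `enable_mem_pattern` is behavior-neutral, since both options default to enabled in ONNX Runtime. On a CPU-only Space the thread-count settings are usually the knobs worth setting instead; the values below are illustrative, not part of the commit:

import os
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# Illustrative CPU tuning (not in the commit): parallelism within an
# operator, and a single stream for running the graph itself.
sess_options.intra_op_num_threads = os.cpu_count() or 1
sess_options.inter_op_num_threads = 1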
@@ -126,106 +119,94 @@ def init_models():
 
     print("✅ All models loaded successfully!")
 
-def generate_caption(image: Image.Image, task: str = "<MORE_DETAILED_CAPTION>", max_new_tokens: int = 256):
-    """Generate caption using Florence-2 ONNX models with separate decoder"""
+def generate_caption(image, task="<MORE_DETAILED_CAPTION>", max_new_tokens=256):
+    """Generate caption using Florence-2 ONNX models"""
 
-    # Prepare inputs using processor
+    # Prepare inputs
     inputs = processor(text=task, images=image, return_tensors="np")
-
-    # Get shapes
     batch_size = 1
 
-    # 1. Run Vision Encoder
+    # 1. Vision Encoder
     pixel_values = inputs["pixel_values"].astype(np.float16)
     vision_outputs = sessions["vision_encoder"].run(None, {"pixel_values": pixel_values})
     image_features = vision_outputs[0]
 
-    # 2. Run Embed Tokens
+    # 2. Embed Tokens
     input_ids = inputs["input_ids"].astype(np.int64)
     embed_outputs = sessions["embed_tokens"].run(None, {"input_ids": input_ids})
     text_embeds = embed_outputs[0]
 
-    # 3. Concatenate image and text embeddings for encoder
+    # 3. Concatenate for encoder
    combined_embeds = np.concatenate([image_features, text_embeds], axis=1)
-
-    # Create attention mask for combined sequence
     vision_seq_len = image_features.shape[1]
     text_seq_len = text_embeds.shape[1]
     combined_seq_len = vision_seq_len + text_seq_len
-
     encoder_attention_mask = np.ones((batch_size, combined_seq_len), dtype=np.int64)
 
-    # 4. Run Encoder
+    # 4. Encoder
     encoder_outputs = sessions["encoder_model"].run(None, {
         "inputs_embeds": combined_embeds.astype(np.float16),
         "attention_mask": encoder_attention_mask
     })
     encoder_hidden_states = encoder_outputs[0]
 
-    # 5. Generation with separate decoder models
-    # Use decoder_model for first step, decoder_with_past_model for subsequent steps
+    # 5. Generation
     generated_ids = input_ids.copy()
-    past_key_values = None
+    use_past = False
 
     for i in range(max_new_tokens):
-        if past_key_values is None:
-            # First iteration - use decoder_model (no past)
+        if not use_past:
+            # First step - use decoder_model
             decoder_inputs = {
-                "input_ids": generated_ids,  # Full sequence for first step
+                "input_ids": generated_ids,
                 "encoder_hidden_states": encoder_hidden_states.astype(np.float16),
                 "encoder_attention_mask": encoder_attention_mask.astype(np.int64)
             }
-
             decoder_outputs = sessions["decoder_model"].run(None, decoder_inputs)
             logits = decoder_outputs[0]
-
-            # Extract past key values for next iteration (if provided by model)
-            if len(decoder_outputs) > 1:
-                past_key_values = decoder_outputs[1:]
+            use_past = True  # Switch to past model for next iteration
         else:
-            # Subsequent iterations - use decoder_with_past_model
-            # Only feed the last token
+            # Subsequent steps - use decoder_with_past_model
             decoder_inputs = {
-                "input_ids": generated_ids[:, -1:],  # Last token only
                "encoder_hidden_states": encoder_hidden_states.astype(np.float16),
-                "encoder_attention_mask": encoder_attention_mask.astype(np.int64),
-                "past_key_values": past_key_values
+                "input_ids": generated_ids[:, -1:],
+                "encoder_attention_mask": encoder_attention_mask.astype(np.int64)
             }
-
             decoder_outputs = sessions["decoder_with_past_model"].run(None, decoder_inputs)
             logits = decoder_outputs[0]
-
-            # Update past key values
-            if len(decoder_outputs) > 1:
-                past_key_values = decoder_outputs[1:]
 
-        # Get next token (greedy decoding) - use last position
+        # Greedy decoding
         next_token_logits = logits[:, -1, :]
         next_token_id = np.argmax(next_token_logits, axis=-1, keepdims=True)
-
-        # Append to generated sequence
        generated_ids = np.concatenate([generated_ids, next_token_id], axis=1)
 
-        # Check for EOS token (2 is typically EOS for Florence-2)
+        # Check for EOS (token 2)
         if next_token_id[0, 0] == 2:
             break
 
-    # Decode output
+    # Decode
     generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
 
-    # Post-process based on task
+    # Post-process
     try:
         result = processor.post_process_generation(generated_text, task, image.size)
-    except Exception as e:
+    except:
         result = {task: generated_text}
 
     return result
 
 def analyze_image(image, task_type):
-    """Main analysis function for Gradio"""
+    """Main analysis function"""
     if image is None:
         return "Please upload an image first."
 
+    # Initialize models on first use if not already done
+    if not sessions:
+        try:
+            init_models()
+        except Exception as e:
+            return f"Initialization error: {str(e)}"
+
     try:
         if isinstance(image, np.ndarray):
             image = Image.fromarray(image)
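One thing to flag in the rewritten loop: the `decoder_with_past_model` branch no longer receives any past key/values (and the old code's single `"past_key_values"` entry would not have matched an ONNX input name either), so depending on how the graph was exported, the run will either fail with a missing-input error or degenerate into decoding from only the last token. In Optimum-style exports, which onnx-community repositories generally follow, the cache travels as individually named tensors (`past_key_values.0.decoder.key`, ...) and comes back as matching `present.*` outputs. A hedged sketch of the usual wiring for the `else:` branch above, reusing the loop's variables; verify the exact names with `sess.get_inputs()` and `sess.get_outputs()`:

sess = sessions["decoder_with_past_model"]
output_names = [o.name for o in sess.get_outputs()]

# Normally seeded from the first decoder_model step using the same
# present -> past_key_values renaming shown below; shown empty here
# only to keep the sketch self-contained.
past = {}

decoder_inputs = {
    "input_ids": generated_ids[:, -1:],
    "encoder_hidden_states": encoder_hidden_states.astype(np.float16),
    "encoder_attention_mask": encoder_attention_mask.astype(np.int64),
    **past,
}
outputs = sess.run(None, decoder_inputs)
logits = outputs[0]
past = {
    name.replace("present", "past_key_values", 1): value
    for name, value in zip(output_names, outputs)
    if name.startswith("present.")
}

On the same theme, `tokenizer.eos_token_id` is a safer stop condition than the hardcoded `2`, although `2` is indeed `</s>` in the BART-style vocabulary Florence-2 uses.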
@@ -248,99 +229,85 @@ def analyze_image(image, task_type):
 
     except Exception as e:
         import traceback
-        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
-        return error_msg
+        return f"Error: {str(e)}\n\n{traceback.format_exc()}"
 
 def create_prompt_from_analysis(image, analysis_result, prompt_type):
-    """Convert analysis to AI generation prompts"""
+    """Convert analysis to AI prompts"""
+    if not analysis_result or analysis_result == "Please upload an image first.":
+        return "Please analyze an image first."
+
     try:
-        if isinstance(analysis_result, str):
-            try:
-                analysis = json.loads(analysis_result)
-            except:
-                analysis = {"description": analysis_result}
-        else:
-            analysis = analysis_result
+        # Parse analysis
+        try:
+            analysis = json.loads(analysis_result)
+        except:
+            analysis = {"description": analysis_result}
 
+        # Extract description
         if isinstance(analysis, dict):
-            if "<MORE_DETAILED_CAPTION>" in analysis:
-                description = analysis["<MORE_DETAILED_CAPTION>"]
-            elif "<CAPTION>" in analysis:
-                description = analysis["<CAPTION>"]
-            elif "<OCR>" in analysis:
-                description = analysis["<OCR>"]
-            else:
-                description = str(analysis)
+            description = (analysis.get("<MORE_DETAILED_CAPTION>") or
+                           analysis.get("<CAPTION>") or
+                           analysis.get("<OCR>") or
+                           str(analysis))
         else:
             description = str(analysis)
 
+        # Generate prompts
         if prompt_type == "Midjourney":
-            prompt = f"""Midjourney Prompt:
+            return f"""Midjourney Prompt:
 A highly detailed photograph of {description}, cinematic lighting, 8k resolution, sharp focus, professional photography, trending on ArtStation --ar 16:9 --v 6.0"""
 
         elif prompt_type == "Stable Diffusion":
-            prompt = f"""Positive Prompt:
+            return f"""Positive Prompt:
 {description}, masterpiece, best quality, highly detailed, 8k, sharp focus, professional photography, cinematic lighting, vibrant colors
 
 Negative Prompt:
 low quality, blurry, distorted, deformed, ugly, duplicate, watermark, signature, text, cropped, worst quality, jpeg artifacts"""
 
         elif prompt_type == "DALL-E":
-            prompt = f"""DALL-E Prompt:
+            return f"""DALL-E Prompt:
 A professional, high-quality image showing {description}. The image should be photorealistic with excellent composition, lighting, and detail. Suitable for commercial use."""
 
         elif prompt_type == "Master Creator Prompt":
-            prompt = f"""MASTER PROMPT ANALYSIS
+            keywords = ', '.join(str(description).split()[:20])
+            return f"""MASTER PROMPT ANALYSIS
 ========================
 
-SOURCE ANALYSIS:
-{description}
+SOURCE: {description}
 
 TECHNICAL BREAKDOWN:
-- Subject Matter: [Identify main subjects]
-- Composition: [Rule of thirds, symmetry, framing]
-- Lighting: [Natural/artificial, direction, quality]
-- Color Palette: [Dominant colors, contrast]
-- Style/Genre: [Photographic style]
-- Mood/Atmosphere: [Emotional tone]
-- Technical Aspects: [Camera angle, lens choice]
+- Subject Matter: Main subjects identified
+- Composition: Rule of thirds, symmetry, framing
+- Lighting: Natural/artificial, direction, quality
+- Color Palette: Dominant colors, contrast
+- Style: Photographic style and genre
+- Mood: Emotional tone and atmosphere
 
 RECREATION GUIDE:
-To recreate: {description}
+Focus on: {description}
 
-KEYWORDS:
-{', '.join(str(description).split()[:20])}, masterpiece, detailed, professional
+KEYWORDS: {keywords}, masterpiece, detailed, professional
 
 VARIATIONS:
 1. Golden hour lighting
-2. Black and white
+2. Black and white conversion
 3. Dramatic shadows
-4. Bird's eye view
-5. Cinematic teal/orange grading"""
-
-        return prompt
+4. Bird's eye view perspective
+5. Cinematic color grading"""
 
     except Exception as e:
         return f"Error creating prompt: {str(e)}"
 
-# Initialize
-print("🚀 Initializing Florence-2 ONNX Space...")
-try:
-    import transformers
-    print(f"Transformers version: {transformers.__version__}")
-    print(f"ONNX Runtime version: {ort.__version__}")
-except ImportError as e:
-    print(f"Import error: {e}")
-
-init_models()
-
 # Gradio Interface
+print("🚀 Starting Gradio app...")
+
 with gr.Blocks(title="Florence-2 Vision Analyzer") as demo:
     gr.Markdown("""
     # 🎨 Florence-2 Vision Analyzer & Prompt Generator
-    ### Powered by ONNX Runtime (FP16) - CPU Mode
+    ### Powered by ONNX Runtime (FP16)
 
     Upload an image to analyze it and generate AI-ready prompts!
+    *Models will download on first use (~550MB)*
     """)
 
     with gr.Row():
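A parsing detail in `create_prompt_from_analysis`: `generate_caption` returns a Python dict, so if the analysis textbox ends up holding that dict's repr (single quotes rather than JSON), `json.loads` always raises and execution lands in the `{"description": ...}` fallback, never reaching the `<MORE_DETAILED_CAPTION>` lookup. A hedged sketch of a parser that accepts both notations; `parse_analysis` is a hypothetical helper, not part of the commit:

import ast
import json

def parse_analysis(text: str) -> dict:
    # Try strict JSON first, then Python literal syntax (which handles the
    # single-quoted repr of a dict); fall back to a plain description.
    for parser in (json.loads, ast.literal_eval):
        try:
            result = parser(text)
            if isinstance(result, dict):
                return result
        except (ValueError, SyntaxError):
            continue
    return {"description": text}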
@@ -357,7 +324,7 @@ with gr.Blocks(title="Florence-2 Vision Analyzer") as demo:
         analysis_output = gr.Textbox(
             label="Analysis Result",
             lines=10,
-            placeholder="Analysis will appear here..."
+            placeholder="Click 'Analyze Image' to process..."
         )
 
     gr.Markdown("---")
@@ -390,14 +357,6 @@ with gr.Blocks(title="Florence-2 Vision Analyzer") as demo:
         outputs=prompt_output,
         api_visibility="public"
     )
-
-    gr.Examples(
-        examples=[
-            [{"path": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"}],
-        ],
-        inputs=input_image,
-        label="Try Example Image"
-    )
 
 if __name__ == "__main__":
     demo.launch(