Mandour committed on
Commit a8207bc · 1 Parent(s): 0e9d4a5

upload initial files

Files changed (5)
  1. .gitignore +12 -0
  2. README.md +14 -0
  3. app.py +531 -0
  4. models.py +290 -0
  5. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # Ignore Python virtual environments
+ /gp-env/
+ /demo-env/
+ /hf_cache/
+
+ # Ignore environment variable files
+ .env
+ .env.example
+
+ # Ignore __pycache__
+ __pycache__/
+ *.pyc
README.md CHANGED
@@ -12,3 +12,17 @@ short_description: Attribute Value Extraction Demo
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ ## Environment Variables
+
+ Set the following environment variables in your Hugging Face Space (Settings → Secrets and environment variables):
+
+ - `ROBERTA_TOKEN`: Hugging Face token with access to the Roberta model weights.
+ - `MERGER_MODEL_TOKEN`: Hugging Face token with access to the Merger model weights.
+
+ If running locally, you can create a `.env` file with these variables:
+
+ ```
+ ROBERTA_TOKEN=your_hf_token_here
+ MERGER_MODEL_TOKEN=your_hf_token_here
+ ```
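
A minimal sketch of how these secrets are consumed at runtime; it mirrors the conditional `load_dotenv()` block that `app.py` and `models.py` (below) both use, with an explicit guard added here purely for illustration:

```python
import os

# Optional for local runs; Spaces inject secrets directly into the environment
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

roberta_token = os.getenv("ROBERTA_TOKEN")      # gates the RoBERTa weight download
merger_token = os.getenv("MERGER_MODEL_TOKEN")  # gates the merger-model weight download
if not (roberta_token and merger_token):
    # Illustrative guard; the committed code falls back to None and fails later at download time
    raise RuntimeError("Set ROBERTA_TOKEN and MERGER_MODEL_TOKEN before loading the models")
```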
app.py ADDED
@@ -0,0 +1,531 @@
+ import gradio as gr
+ import pandas as pd
+ import json
+ import time
+ from typing import Tuple
+ from PIL import Image
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ import os
+ from models import (
+     get_device, get_tokenizers, get_image_processor,
+     load_merger_model, get_predicated_values
+ )
+
+ # Load environment variables (optional for local dev; Spaces use the web UI for env vars)
+ if os.path.exists('.env'):
+     from dotenv import load_dotenv
+     load_dotenv()
+
+ # Global constants
+ ATTRIBUTES_LIST = ['sleeve', 'color', 'type', 'pattern',
+                    'material', 'style', 'neck', 'gender', 'brand']
+ MAX_SEQ_LENGTH = 256
+ DECODER_MAX_SEQ_LENGTH = 64
+
+ # Global variables for model components
+ MODEL_COMPONENTS = None
+ MODEL_LOADED = False
+
+ def initialize_model_and_tokenizers():
+     """Initialize the model and tokenizers once and cache them globally."""
+     global MODEL_COMPONENTS, MODEL_LOADED
+
+     if MODEL_LOADED and MODEL_COMPONENTS:
+         return MODEL_COMPONENTS
+
+     try:
+         print("🔄 Loading AI model components...")
+         device = get_device()
+         bert_tokenizer, roberta_tokenizer = get_tokenizers()
+         image_processor = get_image_processor()
+         model = load_merger_model(bert_tokenizer, device)
+
+         MODEL_COMPONENTS = {
+             'model': model,
+             'bert_tokenizer': bert_tokenizer,
+             'roberta_tokenizer': roberta_tokenizer,
+             'image_processor': image_processor,
+             'device': device
+         }
+         MODEL_LOADED = True
+         print("✅ Model loaded successfully!")
+         return MODEL_COMPONENTS
+     except Exception as e:
+         print(f"❌ Failed to load model: {str(e)}")
+         raise e
+
+ def validate_inputs(image, text_input: str, category: str) -> Tuple[bool, str]:
+     """Validate that all inputs are provided."""
+     if image is None:
+         return False, "❌ Please upload an image file"
+
+     if not text_input or text_input.strip() == "":
+         return False, "❌ Please provide a product description"
+
+     if not category:
+         return False, "❌ Please select a product category"
+
+     return True, "✅ Inputs validated successfully"
+
+ def resize_image_for_display(image: Image.Image, target_size=(512, 512)) -> Image.Image:
+     """Resize the image for consistent display, preserving aspect ratio."""
+     if image.mode != 'RGBA':
+         image = image.convert('RGBA')
+
+     # Compute new size preserving aspect ratio
+     orig_w, orig_h = image.size
+     max_w, max_h = target_size
+
+     # Determine scale factor
+     scale = min(max_w / orig_w, max_h / orig_h)
+     new_w = int(orig_w * scale)
+     new_h = int(orig_h * scale)
+
+     # Resize with high-quality resampling
+     resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+     return resized
+
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
+     """Convert the image to a batched CHW tensor for the pipeline."""
+     if image.mode != 'RGBA':
+         image = image.convert('RGBA')
+
+     # HWC numpy array -> CHW tensor, plus a leading batch dimension
+     image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1)
+     image_tensor = image_tensor.unsqueeze(0)
+     return image_tensor
+
+ def run_inference(image_tensor: torch.Tensor, description: str, category: str, model_components: dict) -> dict:
+     """Run model inference using the get_predicated_values API."""
+     model = model_components['model']
+     bert_tokenizer = model_components['bert_tokenizer']
+     roberta_tokenizer = model_components['roberta_tokenizer']
+     image_processor = model_components['image_processor']
+     device = model_components['device']
+
+     # Convert the tensor back to a PIL image for the BLIP processor
+     pil_img = transforms.ToPILImage()(image_tensor.squeeze(0).cpu())
+     start_time = time.time()
+
+     results = get_predicated_values(
+         model, category, pil_img, description,
+         image_processor, bert_tokenizer, roberta_tokenizer, device
+     )
+
+     end_time = time.time()
+
+     # Format for UI
+     total_attributes = len([a for a in results if a["value"] and a["value"] != "N/A"])
+     avg_confidence = np.mean([a["confidence"] for a in results if a["value"]
+                               and a["value"] != "N/A"]) if total_attributes > 0 else 0
+
+     return {
+         "attributes": results,
+         "total_attributes": total_attributes,
+         "avg_confidence": avg_confidence,
+         "processing_time": end_time - start_time
+     }
+
+ def get_confidence_color(confidence: float) -> str:
+     """Get a display color based on the confidence level."""
+     if confidence >= 0.8:
+         return "#28a745"  # Green
+     elif confidence >= 0.6:
+         return "#ffc107"  # Yellow
+     else:
+         return "#dc3545"  # Red
+
+ def format_results_html(results: dict) -> str:
+     """Format results as HTML for display."""
+     if not results or results["total_attributes"] == 0:
+         return """
+         <div style="padding: 20px; text-align: center; background-color: #fff3cd; border-radius: 10px; border: 1px solid #ffeaa7;">
+             <h3 style="color: #856404; margin: 0;">🔍 No attributes extracted</h3>
+             <p style="color: #856404; margin: 10px 0 0 0;">Try with a different image or more detailed description.</p>
+         </div>
+         """
+
+     html = """
+     <div style="padding: 20px;">
+         <h3 style="color: #333; margin-bottom: 20px; font-size: 1.5em;">📊 Extracted Attributes</h3>
+     """
+
+     for attr in results["attributes"]:
+         if attr["value"] != "N/A":
+             confidence = attr["confidence"]
+             color = get_confidence_color(confidence)
+
+             html += f"""
+             <div style="
+                 background: white;
+                 padding: 15px;
+                 margin-bottom: 10px;
+                 border-radius: 10px;
+                 box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+                 border-left: 4px solid #667eea;
+                 display: flex;
+                 justify-content: space-between;
+                 align-items: center;
+             ">
+                 <div>
+                     <strong style="color: #333; font-size: 1.1em;">{attr["name"].title()}</strong>
+                     <span style="color: #666; margin-left: 10px;">{attr["value"]}</span>
+                 </div>
+                 <div style="
+                     background-color: {color};
+                     color: white;
+                     padding: 4px 8px;
+                     border-radius: 12px;
+                     font-size: 0.8em;
+                     font-weight: bold;
+                 ">
+                     {confidence:.1%}
+                 </div>
+             </div>
+             """
+
+     # Add summary statistics
+     html += f"""
+         <div style="
+             background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+             color: white;
+             padding: 15px;
+             border-radius: 10px;
+             margin-top: 20px;
+             text-align: center;
+         ">
+             <h4 style="margin: 0;">📈 Summary</h4>
+             <p style="margin: 10px 0 0 0;">
+                 <strong>{results["total_attributes"]}</strong> attributes extracted |
+                 <strong>{results["avg_confidence"]:.1%}</strong> avg confidence |
+                 <strong>{results["processing_time"]:.2f}s</strong> processing time
+             </p>
+         </div>
+     </div>
+     """
+
+     return html
+
+ def create_download_files(results: dict) -> Tuple[str, str]:
+     """Create JSON and CSV files for download."""
+     if not results:
+         return None, None
+
+     # JSON file
+     json_content = json.dumps(results, indent=2)
+     json_file = "attributes.json"
+     with open(json_file, "w") as f:
+         f.write(json_content)
+
+     # CSV file
+     df = pd.DataFrame(results["attributes"])
+     csv_file = "attributes.csv"
+     df.to_csv(csv_file, index=False)
+
+     return json_file, csv_file
+
+ def process_inputs(image, category, description, progress=gr.Progress()):
+     """Main processing function."""
+     global MODEL_COMPONENTS
+
+     # Initialize the model if needed (lazy-loaded on the first request)
+     if not MODEL_LOADED:
+         progress(0.1, desc="Loading AI model...")
+         try:
+             MODEL_COMPONENTS = initialize_model_and_tokenizers()
+         except Exception as e:
+             error_msg = f"❌ Failed to load model: {str(e)}"
+             return None, error_msg, None, None, None
+
+     # Validate inputs
+     is_valid, validation_message = validate_inputs(image, description, category)
+     if not is_valid:
+         return None, validation_message, None, None, None
+
+     try:
+         # Step 1: Image preprocessing
+         progress(0.3, desc="📸 Preprocessing image...")
+         resized_image = resize_image_for_display(image, (512, 512))
+         image_tensor = preprocess_image(resized_image)
+
+         # Step 2: Model inference
+         progress(0.7, desc="🧠 Running AI inference...")
+         results = run_inference(image_tensor, description, category, MODEL_COMPONENTS)
+
+         # Step 3: Format results
+         progress(0.9, desc="📊 Formatting results...")
+         results_html = format_results_html(results)
+
+         # Create download files
+         json_file, csv_file = create_download_files(results)
+
+         progress(1.0, desc="✅ Processing complete!")
+
+         success_msg = f"🎉 Successfully extracted {results['total_attributes']} attributes!"
+         return resized_image, success_msg, results_html, json_file, csv_file
+
+     except Exception as e:
+         error_msg = f"❌ Processing failed: {str(e)}"
+         return None, error_msg, None, None, None
+
+ # Custom CSS for styling
+ custom_css = """
+ /* Global styling */
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ /* Header styling */
+ .header-text {
+     text-align: center;
+     color: #333;
+     margin-bottom: 30px;
+ }
+
+ /* Input section styling */
+ .input-section {
+     background: #f8f9fa;
+     padding: 20px;
+     border-radius: 15px;
+     margin-bottom: 20px;
+ }
+
+ /* Button styling */
+ .primary-button {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+     border: none !important;
+     color: white !important;
+     font-weight: bold !important;
+     padding: 12px 24px !important;
+     border-radius: 25px !important;
+     font-size: 16px !important;
+ }
+
+ /* Results section styling */
+ .results-section {
+     background: white;
+     padding: 20px;
+     border-radius: 15px;
+     box-shadow: 0 4px 15px rgba(0,0,0,0.1);
+ }
+
+ /* Status message styling */
+ .status-positive {
+     color: #28a745;
+     font-weight: bold;
+     padding: 10px;
+     background-color: #d4edda;
+     border-radius: 8px;
+     border: 1px solid #c3e6cb;
+ }
+
+ .status-negative {
+     color: #721c24;
+     font-weight: bold;
+     padding: 10px;
+     background-color: #f8d7da;
+     border-radius: 8px;
+     border: 1px solid #f5c6cb;
+ }
+
+ /* Info box styling */
+ .info-box {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 20px;
+     border-radius: 15px;
+     margin: 20px 0;
+ }
+
+ /* Tips styling */
+ .tips-section {
+     background: #e9ecef;
+     padding: 15px;
+     border-radius: 10px;
+     margin-top: 20px;
+ }
+ """
+
+ # Create Gradio interface
+ def create_interface():
+     """Create the main Gradio interface."""
+
+     with gr.Blocks(css=custom_css, title="AI Attribute Extractor", theme=gr.themes.Soft()) as demo:
+
+         # Header
+         gr.HTML("""
+         <div class="header-text">
+             <h1>🔍 AI Attribute Extractor</h1>
+             <p style="font-size: 1.1em; color: #666;">Upload an image and provide text to extract detailed attributes using AI</p>
+         </div>
+         """)
+
+         with gr.Row():
+             # Left column - Input section
+             with gr.Column(scale=1):
+                 gr.HTML("<h2>📤 Input Section</h2>")
+
+                 # Image upload
+                 image_input = gr.Image(
+                     label="Upload Product Image",
+                     type="pil",
+                     height=300,
+                     elem_classes=["input-section"]
+                 )
+
+                 # Category selection
+                 category_input = gr.Dropdown(
+                     choices=["clothing", "bags", "shoes", "accessories"],
+                     label="Product Category",
+                     value="clothing",
+                     elem_classes=["input-section"]
+                 )
+
+                 # Text description
+                 text_input = gr.Textbox(
+                     label="Product Description",
+                     placeholder="Describe the product in detail...",
+                     lines=4,
+                     elem_classes=["input-section"]
+                 )
+
+                 # Process button
+                 process_btn = gr.Button(
+                     "🚀 Extract Attributes",
+                     variant="primary",
+                     size="lg",
+                     elem_classes=["primary-button"]
+                 )
+
+                 # Status message
+                 status_msg = gr.HTML(label="Status")
+
+             # Right column - Results section
+             with gr.Column(scale=1):
+                 gr.HTML("<h2>📊 Results Section</h2>")
+
+                 # Processed image display
+                 processed_image = gr.Image(
+                     label="Processed Image",
+                     height=300,
+                     elem_classes=["results-section"]
+                 )
+
+                 # Results display
+                 results_html = gr.HTML(
+                     label="Extracted Attributes",
+                     elem_classes=["results-section"]
+                 )
+
+                 # Download buttons
+                 with gr.Row():
+                     json_download = gr.File(
+                         label="📄 Download JSON",
+                         visible=False
+                     )
+                     csv_download = gr.File(
+                         label="📊 Download CSV",
+                         visible=False
+                     )
+
+         # Info section
+         with gr.Row():
+             with gr.Column():
+                 gr.HTML("""
+                 <div class="info-box">
+                     <h3>ℹ️ About This Tool</h3>
+                     <p>This AI-powered tool extracts product attributes from images and text descriptions using:</p>
+                     <ul>
+                         <li><strong>🖼️ Vision Transformer (DeiT)</strong> for image analysis</li>
+                         <li><strong>🔤 BERT & RoBERTa</strong> for text understanding</li>
+                         <li><strong>🧠 Hierarchical Fusion</strong> for multimodal learning</li>
+                         <li><strong>⚡ LoRA/DoRA</strong> for efficient fine-tuning</li>
+                     </ul>
+                 </div>
+                 """)
+
+             with gr.Column():
+                 gr.HTML(f"""
+                 <div class="tips-section">
+                     <h3>🎯 Tips for Better Results</h3>
+                     <ul>
+                         <li>Use clear, well-lit images</li>
+                         <li>Provide detailed descriptions</li>
+                         <li>Include specific product details</li>
+                         <li>Avoid blurry or low-quality images</li>
+                     </ul>
+                     <h4>Supported Attributes:</h4>
+                     <p>{', '.join([attr.title() for attr in ATTRIBUTES_LIST])}</p>
+                 </div>
+                 """)
+
+         # Event handlers
+         def update_status(message: str, is_error: bool = False):
+             """Wrap the status message in a styled div."""
+             class_name = "status-negative" if is_error else "status-positive"
+             return f'<div class="{class_name}">{message}</div>'
+
+         def process_and_update(image, category, description):
+             """Process inputs and update all outputs."""
+             processed_img, status, results, json_file, csv_file = process_inputs(
+                 image, category, description
+             )
+
+             # Update status with styling
+             is_error = status.startswith("❌")
+             styled_status = update_status(status, is_error)
+
+             # Show download buttons if successful
+             json_visible = json_file is not None
+             csv_visible = csv_file is not None
+
+             return (
+                 processed_img,
+                 styled_status,
+                 results,
+                 gr.update(value=json_file, visible=json_visible),
+                 gr.update(value=csv_file, visible=csv_visible)
+             )
+
+         # Connect the process button
+         process_btn.click(
+             fn=process_and_update,
+             inputs=[image_input, category_input, text_input],
+             outputs=[processed_image, status_msg, results_html, json_download, csv_download]
+         )
+
+         # Example inputs
+         gr.Examples(
+             examples=[
+                 [
+                     "https://example.com/sample_image.jpg",  # You can replace with actual sample images
+                     "clothing",
+                     "A stylish red cotton t-shirt with short sleeves and a round neck, perfect for casual wear."
+                 ]
+             ],
+             inputs=[image_input, category_input, text_input],
+             label="Try these examples"
+         )
+
+     return demo
+
+ # Launch the app
+ if __name__ == "__main__":
+     # The model itself loads lazily on the first request (see process_inputs)
+     print("Initializing AI Attribute Extractor...")
+
+     # Create and launch the interface
+     demo = create_interface()
+
+     # Launch configuration
+     demo.launch(
+         server_name="0.0.0.0",  # For Hugging Face Spaces
+         server_port=7860,       # Default port for Hugging Face Spaces
+         share=False,            # Set to True for public sharing
+         debug=False,            # Set to True for development
+         show_error=True,        # Show error messages
+         quiet=False             # Set to True to reduce logging
+     )
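
For quick headless testing, the Gradio layer can be bypassed: `process_inputs` takes a PIL image, a category, and a description, and returns the resized image, a status string, the results HTML, and the two download file paths. A sketch under two assumptions: the sample image path is hypothetical, and the default `gr.Progress()` tolerates being called outside a Gradio event:

```python
from PIL import Image
from app import process_inputs

# Hypothetical local file; substitute any product photo
img = Image.open("sample_tshirt.jpg")

resized, status, html, json_path, csv_path = process_inputs(
    img, "clothing",
    "A red cotton t-shirt with short sleeves and a round neck."
)
print(status)     # e.g. "🎉 Successfully extracted 7 attributes!"
print(json_path)  # "attributes.json" on success, None on failure
```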
models.py ADDED
@@ -0,0 +1,290 @@
+ from transformers import (AutoProcessor,
+                           RobertaConfig,
+                           BertTokenizerFast,
+                           RobertaTokenizerFast,
+                           RobertaModel,
+                           BlipForQuestionAnswering)
+ from huggingface_hub import hf_hub_download
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import os
+
+ # Load environment variables (optional for local dev; Spaces use the web UI for env vars)
+ if os.path.exists('.env'):
+     from dotenv import load_dotenv
+     load_dotenv()
+
+ ATTRIBUTES_LIST = ['sleeve', 'type', 'pattern', 'material',
+                    'neck', 'color', 'style', 'brand', 'gender']
+
+ HF_CACHE_DIR = "./hf_cache"
+
+
+ def get_device():
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def get_tokenizers():
+     bert_tokenizer = BertTokenizerFast.from_pretrained(
+         "google-bert/bert-base-uncased", cache_dir=HF_CACHE_DIR)
+     roberta_tokenizer = RobertaTokenizerFast.from_pretrained(
+         "FacebookAI/roberta-base", cache_dir=HF_CACHE_DIR)
+     bert_tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+     return bert_tokenizer, roberta_tokenizer
+
+
+ def get_image_processor():
+     return AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=HF_CACHE_DIR)
+
+
+ class AttentionModalityMerger(nn.Module):
+     def __init__(self, text_dim, image_dim):
+         super().__init__()
+         self.text_layer_norm = nn.LayerNorm(text_dim)
+         self.image_layer_norm = nn.LayerNorm(image_dim)
+         self.linear = nn.Linear(
+             in_features=image_dim + text_dim, out_features=1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, text_embedds, image_features, attention_mask):
+         # Zero out padding positions, then sum-pool text and image features
+         input_mask_expanded = attention_mask.unsqueeze(
+             -1).expand(text_embedds.size()).float()
+         text_embedds = input_mask_expanded * text_embedds
+         text_embedds = text_embedds.sum(dim=1)
+         text_embedds_norm = self.text_layer_norm(text_embedds)
+         image_features = image_features.sum(dim=1)
+         image_features_norm = self.image_layer_norm(image_features)
+         text_image_embedds = torch.cat(
+             [text_embedds_norm, image_features_norm], dim=-1)
+         # Learned sigmoid gate: how much to weight text vs. image
+         gate_output = self.linear(text_image_embedds)
+         p_txt = self.sigmoid(gate_output)
+         p_img = 1 - p_txt
+         scaled_text = p_txt * text_embedds_norm
+         scaled_image = p_img * image_features_norm
+         final_output = torch.cat([scaled_text, scaled_image], dim=-1)
+         return final_output, p_txt, p_img
+
+
+ class RobertaTokenClassificationWithCRF(nn.Module):
+     def __init__(self, vocab_size, device, roberta_token=None):
+         if roberta_token is None:
+             roberta_token = os.getenv("ROBERTA_TOKEN")
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.config = RobertaConfig()
+         self.roberta = RobertaModel.from_pretrained(
+             "FacebookAI/roberta-base", output_hidden_states=True, cache_dir=HF_CACHE_DIR)
+         self.freeze_layers()
+         self._loadTextWeights(device, roberta_token)
+
+     def _loadTextWeights(self, device, roberta_token):
+         repo_id = "LomaaZakaria/Roberta_Attribute_Value_Extraction_Model"
+         weights_file_name = "RobertaCRFWithNOAnswerClassifier_OnFashionGenData_2epochs.pth"
+         weights_file_path = hf_hub_download(
+             repo_id=repo_id, filename=weights_file_name, token=roberta_token, cache_dir=HF_CACHE_DIR)
+         state_dict = torch.load(
+             weights_file_path, weights_only=True, map_location=device)
+         # Keep only checkpoint entries whose names and shapes match the backbone
+         text_model_state_dict = self.roberta.state_dict()
+         filtered_state_dict = {
+             k: v for k, v in state_dict.items()
+             if k in text_model_state_dict and v.shape == text_model_state_dict[k].shape
+         }
+         self.roberta.load_state_dict(filtered_state_dict, strict=False)
+
+     def freeze_layers(self):
+         self.roberta.embeddings.requires_grad_(False)
+         for layers in self.roberta.encoder.layer[:8]:
+             for p in layers.parameters():
+                 p.requires_grad = False
+
+     def forward(self, token_ids, attention_mask):
+         outputs = self.roberta(input_ids=token_ids,
+                                attention_mask=attention_mask)
+         last_hidden_state = outputs.hidden_states[-1]
+         return last_hidden_state
+
+
+ class ImageModel(nn.Module):
+     def __init__(self):
+         super(ImageModel, self).__init__()
+         self.vision_model = BlipForQuestionAnswering.from_pretrained(
+             "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).vision_model
+         self._freezeLayers()
+
+     def _freezeLayers(self):
+         self.vision_model.embeddings.requires_grad_(False)
+         for layer in self.vision_model.encoder.layers[:8]:
+             for p in layer.parameters():
+                 p.requires_grad = False
+
+     def forward(self, x):
+         return self.vision_model(x).last_hidden_state
+
+
+ class MergerModel(nn.Module):
+     def __init__(self, vocab_size, device, roberta_token=None):
+         if roberta_token is None:
+             roberta_token = os.getenv("ROBERTA_TOKEN")
+         super().__init__()
+         self.text_decoder = BlipForQuestionAnswering.from_pretrained(
+             "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).text_decoder
+         self.text_encoder = RobertaTokenClassificationWithCRF(
+             vocab_size, device, roberta_token)
+         self.vision_model = ImageModel()
+         text_dim, image_dim = self.text_encoder.config.hidden_size, 768
+         self.attention_merger = AttentionModalityMerger(text_dim, image_dim)
+         self.linear = nn.Linear(in_features=text_dim + image_dim, out_features=text_dim)
+
+     def forward(self, **inputs):
+         text_encoder = self.text_encoder(
+             token_ids=inputs['encoder_token_ids'], attention_mask=inputs['encoder_attention_mask'])
+         vision_encoder = self.vision_model(x=inputs['image'])
+         merger_output, p_txt, p_img = self.attention_merger(
+             text_encoder, vision_encoder, attention_mask=inputs['encoder_attention_mask'])
+         # The fused vector acts as a length-1 "encoder sequence" for the BLIP decoder
+         merger_output = merger_output.unsqueeze(1)
+         batch_size = vision_encoder.shape[0]
+         merger_output_mask = torch.ones(
+             (batch_size, 1), dtype=torch.long, device=vision_encoder.device)
+         merger_output_linear = self.linear(merger_output)
+         decoder_output = self.text_decoder(
+             input_ids=inputs['decoder_input_token_ids'],
+             attention_mask=inputs['decoder_input_attention_mask'],
+             encoder_hidden_states=merger_output_linear,
+             encoder_attention_mask=merger_output_mask,
+             return_dict=True,
+             return_logits=True
+         )
+         logits = decoder_output
+         return logits, p_txt, p_img
+
+
+ def load_merger_model(bert_tokenizer, device, model_token=None):
+     if model_token is None:
+         model_token = os.getenv("MERGER_MODEL_TOKEN")
+     vocab_size = len(bert_tokenizer)
+     model = MergerModel(vocab_size, device)
+     repo_id = "MohamedMosilhy/AttentionMergerModality"
+     weights_file_name = "Freezing_More_NewViTBlipAttentionMergerModality_4epochs_2e_5_withwarmup.pth"
+     weights_file_path = hf_hub_download(
+         repo_id=repo_id, filename=weights_file_name, token=model_token, cache_dir=HF_CACHE_DIR)
+     model.load_state_dict(torch.load(
+         weights_file_path, weights_only=True, map_location=device))
+     model.to(device)
+     model.eval()
+     return model
+
+
+ def model_generate(model, data, text_tokenizer, device, labels=None, max_generated_length=50, testing=False, return_confidence=False):
+     # Greedy token-by-token decoding, starting from the [DEC] BOS token by default
+     if labels is None:
+         labels = '[DEC]'
+     token_labels = text_tokenizer.convert_tokens_to_ids([labels])
+     model.eval()
+     confidences = []
+     for index in range(max_generated_length):
+         decoder_inputs = text_tokenizer(
+             text=labels, max_length=65, padding='max_length', add_special_tokens=False, return_tensors="pt")
+         decoder_data = {
+             "decoder_input_token_ids": decoder_inputs['input_ids'],
+             "decoder_input_attention_mask": decoder_inputs['attention_mask']
+         }
+         inputs = {
+             "image": data['image'].unsqueeze(0).to(device),
+             "encoder_token_ids": data['encoder_token_ids'].unsqueeze(0).to(device),
+             "encoder_attention_mask": data['encoder_attention_mask'].unsqueeze(0).to(device),
+             "decoder_input_token_ids": decoder_data['decoder_input_token_ids'].to(device),
+             "decoder_input_attention_mask": decoder_data['decoder_input_attention_mask'].to(device)
+         }
+         with torch.no_grad():
+             logits, _, _ = model(**inputs)
+         probs = F.softmax(logits, dim=-1)
+         predicated_label = torch.argmax(
+             probs[:, index, :], dim=-1).cpu().numpy()
+         # Get the confidence for this token
+         confidence = float(
+             probs[0, index, predicated_label[0]].cpu().item())
+         confidences.append(confidence)
+         token_labels.append(predicated_label[0])
+         predicted_tokens = text_tokenizer.convert_ids_to_tokens(
+             predicated_label)
+         labels = text_tokenizer.decode(token_labels)
+         if predicted_tokens[0] == text_tokenizer.sep_token:
+             break
+     predicated_attribute_value = text_tokenizer.decode(token_labels)
+     if testing:
+         token_labels = np.array(token_labels)
+         dec_token_id = text_tokenizer.bos_token_id
+         token_labels = token_labels[token_labels != dec_token_id]
+         return token_labels
+     if return_confidence:
+         # Use the minimum confidence across the generated tokens as the attribute confidence
+         return predicated_attribute_value, min(confidences) if confidences else 0.0
+     return predicated_attribute_value
+
+
+ # Define which attributes are relevant for each category
+ CATEGORY_ATTRIBUTES = {
+     "clothing": ['sleeve', 'type', 'pattern', 'material', 'neck', 'color', 'style', 'brand', 'gender'],
+     "bags": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
+     "shoes": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
+     "accessories": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
+ }
+
+ def get_predicated_values(
+     model, category, img, desc, image_processor, bert_tokenizer, roberta_tokenizer, device, max_seq_length=256
+ ):
+     results = []
+
+     def _combined_with_CategoriesAttributes(desc, category, attribute):
+         # Note: desc is unused here; the "question" text is just "<category> <attribute>"
+         return category + ' ' + attribute
+
+     def imageProcesser(img):
+         return image_processor(img)
+
+     def _tokenizeText(image, desc, category, attribute):
+         combined_desc = _combined_with_CategoriesAttributes(
+             desc, category, attribute)
+         image_inputs = imageProcesser(image)
+         text_encoder_inputs = roberta_tokenizer(
+             combined_desc,
+             desc,
+             max_length=max_seq_length,
+             padding='max_length',
+             truncation=True,  # keep long descriptions within the encoder's limit
+             return_tensors='np'
+         )
+         return image_inputs, text_encoder_inputs
+
+     # Normalize the category to lower-case and pick its attribute set
+     category_key = str(category).strip().lower()
+     attributes = CATEGORY_ATTRIBUTES.get(category_key, CATEGORY_ATTRIBUTES["clothing"])
+
+     image = img
+     for attribute in attributes:
+         image_inputs, text_encoder_inputs = _tokenizeText(
+             image, desc, category, attribute)
+         image_data = torch.from_numpy(np.array(image_inputs['pixel_values']))
+         encoder_token_ids = torch.from_numpy(
+             np.array(text_encoder_inputs['input_ids']))
+         encoder_attn_mask = torch.from_numpy(
+             np.array(text_encoder_inputs['attention_mask']))
+         inputs = {
+             "image": image_data.squeeze(0),
+             "encoder_token_ids": encoder_token_ids.squeeze(0),
+             "encoder_attention_mask": encoder_attn_mask.squeeze(0),
+         }
+
+         predicated_value, confidence = model_generate(
+             model, inputs, text_tokenizer=bert_tokenizer, device=device, return_confidence=True
+         )
+         # Remove [DEC] and [SEP] tokens and strip whitespace
+         clean_value = predicated_value.replace('[DEC]', '').replace('[SEP]', '').strip()
+         if clean_value != 'not specified':
+             results.append(
+                 {"name": attribute, "value": clean_value,
+                  "confidence": float(confidence)}
+             )
+     return results
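
The public surface of `models.py` that the app relies on is just these five functions. Internally, `AttentionModalityMerger` computes a scalar gate p_txt = sigmoid(W·[text; image]) over the pooled modalities and weights them by p_txt and 1 - p_txt before decoding. An end-to-end sketch of programmatic use, assuming `ROBERTA_TOKEN` and `MERGER_MODEL_TOKEN` are set and the gated weight repos are accessible (the image path is hypothetical):

```python
from PIL import Image
from models import (get_device, get_tokenizers, get_image_processor,
                    load_merger_model, get_predicated_values)

device = get_device()
bert_tok, roberta_tok = get_tokenizers()     # adds the [DEC] BOS token to the BERT vocab
processor = get_image_processor()
model = load_merger_model(bert_tok, device)  # downloads gated weights via MERGER_MODEL_TOKEN

img = Image.open("sample_tshirt.jpg").convert("RGB")  # hypothetical local file
attrs = get_predicated_values(
    model, "clothing", img,
    "A red cotton t-shirt with short sleeves.",
    processor, bert_tok, roberta_tok, device
)
for a in attrs:  # each entry: {"name", "value", "confidence"}
    print(f'{a["name"]}: {a["value"]} ({a["confidence"]:.1%})')
```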
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio>=4.0.0
+ torch>=1.9.0
+ torchvision>=0.10.0
+ transformers>=4.20.0
+ huggingface-hub
+ python-dotenv
+ Pillow>=9.0.0
+ pandas>=1.3.0
+ numpy>=1.21.0