"""
🎨 AI Image Editor Pro - Streamlit Version
=============================================
A private, self-hosted AI image editing tool using open-source models.
Runs on Hugging Face Spaces with Streamlit SDK.
Now with advanced Gemini-style instruction understanding!
"""

import os
import gc
import re
import torch
import numpy as np
import streamlit as st
from PIL import Image
from typing import Tuple, Optional, Dict, List
from io import BytesIO

# ============================================================================
# PAGE CONFIG (must be first Streamlit command)
# ============================================================================

st.set_page_config(
    page_title="🎨 AI Image Editor Pro",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded"
)

# ============================================================================
# CONFIGURATION
# ============================================================================

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
INPAINT_MODEL = "runwayml/stable-diffusion-inpainting"
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"

# ============================================================================
# ADVANCED INSTRUCTION PARSER - GEMINI-STYLE
# ============================================================================

class GeminiStyleParser:
    """
    Advanced natural language parser that understands complex editing instructions
    like Google Gemini. Handles various phrasings, synonyms, and compound commands.
    """
    
    # Comprehensive action patterns with synonyms
    REMOVE_KEYWORDS = [
        "remove", "delete", "erase", "get rid of", "take out", "eliminate",
        "clear", "wipe", "clean up", "take away", "disappear", "vanish",
        "make disappear", "get away", "rid of", "cut out", "crop out",
        "hide", "discard", "throw away", "dispose", "extract", "pull out",
        "subtract", "minus", "without", "lose", "drop", "ditch", "nix",
        "scratch", "strike", "zap", "nuke", "kill", "destroy", "obliterate"
    ]
    
    REPLACE_KEYWORDS = [
        "replace", "swap", "switch", "substitute", "exchange", "trade",
        "put", "place", "add", "insert", "set", "change to", "turn into",
        "transform to", "convert to", "make it", "make this", "transform into",
        "morph into", "become", "evolve into", "shift to"
    ]
    
    CHANGE_KEYWORDS = [
        "change", "modify", "alter", "adjust", "edit", "transform",
        "convert", "turn", "make", "update", "recolor", "repaint",
        "tint", "color", "paint", "dye", "shade", "hue", "tone",
        "brighten", "darken", "lighten", "saturate", "desaturate"
    ]
    
    ADD_KEYWORDS = [
        "add", "insert", "put", "place", "include", "attach",
        "append", "introduce", "bring", "create", "generate",
        "draw", "paint", "render", "give", "apply", "overlay"
    ]
    
    ENHANCE_KEYWORDS = [
        "enhance", "improve", "beautify", "upgrade", "refine",
        "polish", "perfect", "optimize", "boost", "amplify",
        "sharpen", "clarify", "fix", "repair", "restore"
    ]
    
    # Prepositions and connectors
    PREPOSITIONS = [
        "with", "to", "into", "as", "by", "for", "from",
        "using", "via", "through", "in place of", "instead of"
    ]
    
    # Color mappings for better understanding
    COLORS = {
        "red": "vibrant red colored",
        "blue": "deep blue colored",
        "green": "lush green colored",
        "yellow": "bright yellow colored",
        "orange": "warm orange colored",
        "purple": "rich purple colored",
        "pink": "soft pink colored",
        "black": "pure black colored",
        "white": "clean white colored",
        "gold": "shimmering golden colored",
        "silver": "metallic silver colored",
        "brown": "natural brown colored",
        "gray": "neutral gray colored",
        "grey": "neutral grey colored",
        "cyan": "cyan turquoise colored",
        "magenta": "vivid magenta colored",
        "teal": "elegant teal colored",
        "navy": "deep navy blue colored",
        "maroon": "rich maroon colored",
        "olive": "earthy olive colored",
        "coral": "beautiful coral colored",
        "beige": "soft beige colored",
        "tan": "warm tan colored",
        "cream": "creamy off-white colored",
        "mint": "fresh mint green colored",
        "lavender": "delicate lavender colored",
        "rose": "romantic rose colored",
        "burgundy": "deep burgundy colored",
        "bronze": "warm bronze colored"
    }
    
    # Object synonyms for better detection
    OBJECT_SYNONYMS = {
        "person": ["person", "human", "man", "woman", "people", "guy", "girl", "boy", "lady", "gentleman", "individual", "figure", "someone", "somebody", "pedestrian"],
        "sky": ["sky", "clouds", "heaven", "atmosphere", "air above", "skyline"],
        "car": ["car", "vehicle", "automobile", "auto", "ride", "wheels", "sedan", "suv", "truck", "van"],
        "background": ["background", "backdrop", "behind", "scenery", "setting", "surroundings", "environment"],
        "text": ["text", "words", "letters", "writing", "inscription", "watermark", "logo", "signature", "label", "caption"],
        "grass": ["grass", "lawn", "turf", "field", "meadow", "greenery"],
        "tree": ["tree", "plant", "vegetation", "foliage", "bush", "shrub"],
        "water": ["water", "ocean", "sea", "lake", "river", "pond", "pool", "stream"],
        "building": ["building", "house", "structure", "architecture", "construction", "edifice"],
        "animal": ["animal", "pet", "creature", "dog", "cat", "bird"],
        "face": ["face", "facial", "head", "portrait", "visage"],
        "hair": ["hair", "hairstyle", "locks", "mane", "tresses"],
        "clothes": ["clothes", "clothing", "outfit", "dress", "shirt", "pants", "garment", "attire", "wear"],
        "wall": ["wall", "walls", "surface", "partition"],
        "floor": ["floor", "ground", "flooring", "surface below"],
        "window": ["window", "glass", "pane", "windowpane"],
        "door": ["door", "doorway", "entrance", "entry", "gate"]
    }
    
    # Scene/style transformations
    STYLE_TRANSFORMS = {
        "sunset": "beautiful golden sunset sky with orange and pink clouds, dramatic lighting",
        "sunrise": "stunning sunrise with warm golden light, peaceful morning atmosphere",
        "night": "dark nighttime scene with stars, moonlit atmosphere",
        "day": "bright daylight, clear blue sky, natural sunlight",
        "winter": "snowy winter scene, frost covered, cold atmosphere",
        "summer": "bright summer day, warm sunny atmosphere",
        "autumn": "fall colors, orange and brown leaves, autumn atmosphere",
        "spring": "fresh spring scene, blooming flowers, new growth",
        "rain": "rainy weather, wet surfaces, overcast sky",
        "snow": "heavy snowfall, white snow covered, winter wonderland",
        "foggy": "misty foggy atmosphere, soft diffused light",
        "stormy": "dramatic stormy sky, dark clouds, lightning",
        "vintage": "vintage retro aesthetic, warm sepia tones, nostalgic feel",
        "cyberpunk": "neon cyberpunk aesthetic, futuristic, glowing lights",
        "fantasy": "magical fantasy scene, ethereal atmosphere, dreamlike",
        "realistic": "photorealistic, natural, lifelike quality",
        "cartoon": "cartoon animated style, colorful, illustrated",
        "anime": "anime style, japanese animation aesthetic",
        "watercolor": "watercolor painting style, soft brushstrokes",
        "oil painting": "oil painting style, rich textures, artistic",
        "sketch": "pencil sketch style, hand-drawn look",
        "cinematic": "cinematic movie quality, dramatic lighting, film-like",
        "hdr": "high dynamic range, vivid colors, enhanced contrast",
        "dreamy": "soft dreamy atmosphere, ethereal glow, romantic",
        "dramatic": "dramatic lighting, high contrast, intense mood",
        "peaceful": "calm peaceful atmosphere, serene, tranquil",
        "scary": "dark scary atmosphere, horror aesthetic, ominous",
        "happy": "bright cheerful atmosphere, joyful, vibrant colors",
        "sad": "melancholic atmosphere, muted colors, somber mood"
    }
    
    def __init__(self):
        self.last_confidence = 0.0
        self.interpretation = ""
    
    def normalize_text(self, text: str) -> str:
        """Normalize input text for better parsing."""
        text = text.lower().strip()
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove common punctuation that doesn't affect meaning
        text = re.sub(r'[.,!?;:]+$', '', text)
        # Handle contractions
        text = text.replace("don't", "do not")
        text = text.replace("can't", "cannot")
        text = text.replace("won't", "will not")
        text = text.replace("i'd", "i would")
        text = text.replace("i'm", "i am")
        text = text.replace("it's", "it is")
        return text
    
    def extract_target_object(self, text: str) -> str:
        """Extract the target object from the instruction."""
        # Remove common filler words
        filler_words = ["the", "a", "an", "this", "that", "those", "these", "my", "your", "please", "kindly", "can you", "could you", "would you", "i want to", "i'd like to", "i would like to"]
        result = text
        for filler in filler_words:
            result = re.sub(r'\b' + re.escape(filler) + r'\b', '', result, flags=re.IGNORECASE)
        # Collapse the double spaces left behind by the removals
        result = re.sub(r'\s+', ' ', result)
        return result.strip()
    
    def find_best_synonym(self, target: str) -> str:
        """Find the best matching object for CLIPSeg detection."""
        target_lower = target.lower()
        
        # Check if target matches any known synonym
        for main_object, synonyms in self.OBJECT_SYNONYMS.items():
            for synonym in synonyms:
                if synonym in target_lower or target_lower in synonym:
                    return main_object
        
        return target
    
    def enhance_prompt(self, prompt: str) -> str:
        """Enhance the replacement prompt for better results."""
        prompt_lower = prompt.lower()
        
        # Check for style transformations
        for style_key, style_value in self.STYLE_TRANSFORMS.items():
            if style_key in prompt_lower:
                return f"{style_value}, high quality, detailed, professional"
        
        # Check for colors and enhance
        for color_key, color_value in self.COLORS.items():
            if color_key in prompt_lower:
                prompt = prompt.replace(color_key, color_value)
        
        # Add quality modifiers if not present
        quality_terms = ["high quality", "detailed", "professional", "beautiful", "stunning"]
        has_quality = any(term in prompt_lower for term in quality_terms)
        
        if not has_quality:
            prompt = f"{prompt}, high quality, detailed, professional photography"
        
        return prompt
    
    def detect_action_type(self, text: str) -> str:
        """Detect the type of editing action requested."""
        text_lower = text.lower()
        
        # Match whole words/phrases only, so e.g. "add" does not fire inside "ladder"
        def has_keyword(keywords: List[str]) -> bool:
            return any(
                re.search(r'\b' + re.escape(k) + r'\b', text_lower)
                for k in keywords
            )
        
        if has_keyword(self.REMOVE_KEYWORDS):
            return "remove"
        if has_keyword(self.ADD_KEYWORDS):
            return "add"
        if has_keyword(self.REPLACE_KEYWORDS):
            return "replace"
        if has_keyword(self.CHANGE_KEYWORDS):
            return "change"
        if has_keyword(self.ENHANCE_KEYWORDS):
            return "enhance"
        
        return "general"
    
    def parse(self, instruction: str) -> Tuple[str, str, float]:
        """
        Parse the instruction and return (target, replacement_prompt, confidence).
        This is the main parsing method that handles all types of instructions.
        """
        original = instruction
        normalized = self.normalize_text(instruction)
        action_type = self.detect_action_type(normalized)
        
        target = ""
        replacement = ""
        confidence = 0.5
        
        # ===== REMOVE ACTION =====
        if action_type == "remove":
            for keyword in self.REMOVE_KEYWORDS:
                if keyword in normalized:
                    target = normalized.split(keyword, 1)[-1].strip()
                    break
            
            target = self.extract_target_object(target)
            target = self.find_best_synonym(target)
            replacement = "clean empty background, seamless natural texture, nothing there, blank space"
            confidence = 0.85
            self.interpretation = f"🗑️ Remove: Detecting and removing '{target}'"
        
        # ===== ADD ACTION =====
        elif action_type == "add":
            for keyword in self.ADD_KEYWORDS:
                if keyword in normalized:
                    parts = normalized.split(keyword, 1)
                    if len(parts) > 1:
                        target = "main subject area"
                        replacement = parts[1].strip()
                        break
            
            replacement = self.extract_target_object(replacement)
            replacement = self.enhance_prompt(replacement)
            confidence = 0.75
            self.interpretation = f"➕ Add: Adding '{replacement}' to the image"
        
        # ===== REPLACE ACTION =====
        elif action_type == "replace":
            # Try to find "X with Y" or "X to Y" patterns
            preposition_found = False
            for prep in self.PREPOSITIONS:
                if f" {prep} " in normalized:
                    parts = normalized.split(f" {prep} ", 1)
                    
                    # Extract target from first part
                    first_part = parts[0]
                    for keyword in self.REPLACE_KEYWORDS + self.CHANGE_KEYWORDS:
                        # Strip whole words only, so "set" does not eat part of "sunset"
                        first_part = re.sub(r'\b' + re.escape(keyword) + r'\b', '', first_part)
                    target = self.extract_target_object(first_part)
                    target = self.find_best_synonym(target)
                    
                    # Extract replacement from second part
                    replacement = self.extract_target_object(parts[1])
                    replacement = self.enhance_prompt(replacement)
                    
                    preposition_found = True
                    confidence = 0.9
                    break
            
            if not preposition_found:
                # Fallback: try to extract target and use generic replacement
                for keyword in self.REPLACE_KEYWORDS:
                    if keyword in normalized:
                        target = normalized.split(keyword, 1)[-1].strip()
                        target = self.extract_target_object(target)
                        target = self.find_best_synonym(target)
                        replacement = "something different, new object, alternative"
                        confidence = 0.6
                        break
            
            self.interpretation = f"🔄 Replace: Replacing '{target}' with '{replacement[:50]}...'"
        
        # ===== CHANGE ACTION =====
        elif action_type == "change":
            # Look for patterns like "change X to Y" or "make X Y"
            preposition_found = False
            for prep in ["to", "into", "as"]:
                if f" {prep} " in normalized:
                    parts = normalized.split(f" {prep} ", 1)
                    
                    # Extract target from first part
                    first_part = parts[0]
                    for keyword in self.CHANGE_KEYWORDS:
                        # Strip whole words only to avoid mangling the target noun
                        first_part = re.sub(r'\b' + re.escape(keyword) + r'\b', '', first_part)
                    target = self.extract_target_object(first_part)
                    target = self.find_best_synonym(target)
                    
                    # Extract new state from second part
                    new_state = self.extract_target_object(parts[1])
                    
                    # Combine target with new state for replacement
                    replacement = f"{target} that is {new_state}, {self.enhance_prompt(new_state)}"
                    
                    preposition_found = True
                    confidence = 0.85
                    break
            
            if not preposition_found:
                # Check for color changes like "make it red"
                for color in self.COLORS.keys():
                    if color in normalized:
                        target = "main subject"
                        replacement = f"{self.COLORS[color]}, high quality, detailed"
                        confidence = 0.8
                        preposition_found = True
                        break
            
            if not preposition_found:
                target = "main subject"
                replacement = self.enhance_prompt(normalized)
                confidence = 0.6
            
            self.interpretation = f"✏️ Change: Modifying '{target}' → '{replacement[:50]}...'"
        
        # ===== ENHANCE ACTION =====
        elif action_type == "enhance":
            target = "main subject"
            replacement = "enhanced improved professional high quality detailed stunning beautiful"
            confidence = 0.7
            self.interpretation = "✨ Enhance: Improving overall image quality"
        
        # ===== GENERAL/UNKNOWN ACTION =====
        else:
            # Try to intelligently guess from the instruction
            # Check if it's just a noun/object (user wants to remove it)
            words = normalized.split()
            if len(words) <= 3:
                target = self.find_best_synonym(normalized)
                replacement = "clean empty background, seamless natural texture"
                confidence = 0.5
                self.interpretation = f"🤔 Guessing: You might want to remove '{target}'?"
            else:
                # Treat as a creative prompt
                target = "main subject area"
                replacement = self.enhance_prompt(normalized)
                confidence = 0.5
                self.interpretation = f"🎨 Creative: Applying '{replacement[:50]}...'"
        
        # Final cleanup
        target = target.strip() if target else "main subject"
        replacement = replacement.strip() if replacement else "improved version"
        
        # Store confidence
        self.last_confidence = confidence
        
        return target, replacement, confidence
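

# A self-contained sketch of whole-word keyword matching (hedged:
# `_demo_keyword_match` is an illustrative helper, not part of the app).
# Using \b word boundaries instead of plain substring tests keeps a keyword
# such as "add" from firing inside an unrelated word like "ladder".
def _demo_keyword_match(text: str) -> str:
    actions = {"remove": ["remove", "erase"], "add": ["add", "insert"]}
    low = text.lower()
    for action, words in actions.items():
        if any(re.search(r'\b' + re.escape(w) + r'\b', low) for w in words):
            return action
    return "general"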


# Create global parser instance
gemini_parser = GeminiStyleParser()


def parse_instruction(instruction: str) -> Tuple[str, str]:
    """
    Enhanced parsing function that uses the GeminiStyleParser.
    Maintains backward compatibility with existing code.
    """
    target, replacement, _ = gemini_parser.parse(instruction)
    return target, replacement


# ============================================================================
# MODEL CACHING
# ============================================================================

@st.cache_resource
def load_inpaint_pipeline():
    """Load and cache the inpainting pipeline."""
    from diffusers import StableDiffusionInpaintPipeline
    
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        INPAINT_MODEL,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False
    )
    
    pipe = pipe.to(DEVICE)
    
    if DEVICE == "cuda":
        pipe.enable_attention_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
    else:
        pipe.enable_attention_slicing(1)
    
    return pipe


@st.cache_resource
def load_clipseg():
    """Load and cache CLIPSeg for automatic mask generation."""
    from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
    
    processor = CLIPSegProcessor.from_pretrained(CLIPSEG_MODEL)
    model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL)
    model = model.to(DEVICE)
    model.eval()
    
    return processor, model


# ============================================================================
# MASK GENERATION (Enhanced)
# ============================================================================

def generate_mask_clipseg(
    image: Image.Image,
    target_text: str,
    threshold: float = 0.3,
    expand_pixels: int = 10
) -> Optional[Image.Image]:
    """Generate a segmentation mask using CLIPSeg with enhanced detection."""
    try:
        processor, model = load_clipseg()
        
        # Try multiple variations of the target text for better detection
        target_variations = [
            target_text,
            f"a {target_text}",
            f"the {target_text}",
            f"{target_text} in photo",
            f"photo of {target_text}"
        ]
        
        best_mask = None
        best_score = 0
        
        for variation in target_variations:
            inputs = processor(
                text=[variation],
                images=[image],
                padding=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model(**inputs)
                preds = outputs.logits
            
            pred = torch.sigmoid(preds[0]).cpu().numpy()
            score = pred.max()
            
            if score > best_score:
                best_score = score
                best_mask = pred
        
        if best_mask is None:
            return None
        
        # Resize to original image size
        pred_pil = Image.fromarray((best_mask * 255).astype(np.uint8))
        pred_resized = pred_pil.resize(image.size, Image.BILINEAR)
        pred_array = np.array(pred_resized)
        
        # Apply threshold
        mask = (pred_array > (threshold * 255)).astype(np.uint8) * 255
        
        # Expand mask
        if expand_pixels > 0:
            from PIL import ImageFilter
            mask_image = Image.fromarray(mask, mode="L")
            mask_image = mask_image.filter(
                ImageFilter.MaxFilter(size=expand_pixels * 2 + 1)
            )
            mask_image = mask_image.filter(
                ImageFilter.GaussianBlur(radius=3)
            )
            return mask_image
        
        return Image.fromarray(mask, mode="L")
    
    except Exception as e:
        st.error(f"Mask generation error: {str(e)}")
        return None


def process_manual_mask(mask_image: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    """Process a manually uploaded mask."""
    mask = mask_image.convert("L")
    mask = mask.resize(target_size, Image.LANCZOS)
    mask_array = np.array(mask)
    mask_array = ((mask_array > 127) * 255).astype(np.uint8)
    return Image.fromarray(mask_array, mode="L")
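

# A tiny, self-contained check of the 127-threshold binarization used above
# (hedged sketch: `_demo_binarize` is an illustrative helper, not part of the app).
def _demo_binarize(mask_array) -> np.ndarray:
    """Pixels above 127 become foreground (255); everything else becomes 0."""
    return ((np.asarray(mask_array) > 127) * 255).astype(np.uint8)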


# ============================================================================
# IMAGE INPAINTING (Enhanced)
# ============================================================================

def inpaint_image(
    image: Image.Image,
    mask: Image.Image,
    prompt: str,
    negative_prompt: str = "blurry, bad quality, distorted, ugly, deformed, low resolution, pixelated, jpeg artifacts, watermark, text, logo",
    num_inference_steps: int = 30,
    guidance_scale: float = 7.5
) -> Optional[Image.Image]:
    """Inpaint the masked region of an image with enhanced prompts."""
    try:
        pipe = load_inpaint_pipeline()
        
        # Resize for SD (512x512)
        original_size = image.size
        target_size = (512, 512)
        
        image_resized = image.resize(target_size, Image.LANCZOS)
        mask_resized = mask.resize(target_size, Image.NEAREST)
        
        if image_resized.mode != "RGB":
            image_resized = image_resized.convert("RGB")
        
        # Adjust steps for CPU
        if DEVICE == "cpu":
            num_inference_steps = min(num_inference_steps, 20)
        
        # Enhanced prompt engineering
        enhanced_prompt = f"{prompt}, masterpiece, best quality, highly detailed, sharp focus, professional"
        
        with torch.inference_mode():
            result = pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                image=image_resized,
                mask_image=mask_resized,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale
            ).images[0]
        
        result = result.resize(original_size, Image.LANCZOS)
        
        if DEVICE == "cpu":
            gc.collect()
        
        return result
    
    except Exception as e:
        st.error(f"Inpainting error: {str(e)}")
        return None


# ============================================================================
# CUSTOM CSS FOR PRO LOOK
# ============================================================================

def inject_custom_css():
    """Inject custom CSS for a more professional look."""
    st.markdown("""
    <style>
    /* Dark theme with gradients */
    .stApp {
        background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
    }
    
    /* Styled headers */
    h1 {
        background: linear-gradient(90deg, #e94560, #0f3460);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 2.5rem !important;
    }
    
    /* Card-like containers */
    .stButton > button {
        background: linear-gradient(90deg, #e94560, #533483);
        border: none;
        border-radius: 10px;
        font-weight: bold;
        transition: all 0.3s ease;
    }
    
    .stButton > button:hover {
        transform: translateY(-2px);
        box-shadow: 0 5px 20px rgba(233, 69, 96, 0.4);
    }
    
    /* Styled file uploader */
    .stFileUploader {
        border: 2px dashed #e94560;
        border-radius: 15px;
        padding: 20px;
    }
    
    /* Confidence indicator */
    .confidence-high {
        color: #4ade80;
        font-weight: bold;
    }
    
    .confidence-medium {
        color: #fbbf24;
        font-weight: bold;
    }
    
    .confidence-low {
        color: #f87171;
        font-weight: bold;
    }
    
    /* Interpretation box */
    .interpretation-box {
        background: rgba(233, 69, 96, 0.1);
        border-left: 4px solid #e94560;
        padding: 10px 15px;
        border-radius: 0 10px 10px 0;
        margin: 10px 0;
    }
    
    /* Pro badge */
    .pro-badge {
        background: linear-gradient(90deg, #e94560, #533483);
        padding: 2px 10px;
        border-radius: 20px;
        font-size: 0.8rem;
        font-weight: bold;
        color: white;
    }
    
    /* Smooth transitions */
    * {
        transition: background-color 0.3s ease, color 0.3s ease;
    }
    </style>
    """, unsafe_allow_html=True)


# ============================================================================
# MAIN APP
# ============================================================================

def main():
    inject_custom_css()
    
    st.markdown("""
    <div style="display: flex; align-items: center; gap: 10px;">
        <h1>🎨 AI Image Editor</h1>
        <span class="pro-badge">PRO</span>
    </div>
    """, unsafe_allow_html=True)
    
    st.markdown("**Gemini-style image editing with advanced prompt understanding - 100% Private!**")
    
    # Sidebar
    with st.sidebar:
        st.header("βš™οΈ Settings")
        
        auto_mask = st.checkbox(
            "πŸ” Auto-detect region",
            value=True,
            help="Automatically find the object to edit using AI"
        )
        
        st.markdown("---")
        st.subheader("🎚️ Advanced Options")
        
        mask_threshold = st.slider(
            "Detection Sensitivity",
            min_value=0.1,
            max_value=0.9,
            value=0.25,
            step=0.05,
            help="Lower = larger detection area"
        )
        
        mask_expansion = st.slider(
            "Mask Expansion (px)",
            min_value=0,
            max_value=50,
            value=15,
            step=2,
            help="Expand the detected area for better blending"
        )
        
        num_steps = st.slider(
            "Quality Steps",
            min_value=10,
            max_value=50,
            value=20 if DEVICE == "cpu" else 35,
            step=5,
            help="More = better quality but slower"
        )
        
        guidance_scale = st.slider(
            "Prompt Strength",
            min_value=1.0,
            max_value=15.0,
            value=8.5,
            step=0.5,
            help="Higher = more closely follows your instructions"
        )
        
        st.markdown("---")
        device_emoji = "🚀" if DEVICE == "cuda" else "💻"
        st.info(f"{device_emoji} Device: **{DEVICE.upper()}**")
        
        if DEVICE == "cpu":
            st.warning("⚠️ Running on CPU. Edits may take 1-3 minutes.")
        else:
            st.success("✅ GPU detected! Fast processing enabled.")
    
    # Main content
    col1, col2 = st.columns(2)
    
    with col1:
        st.subheader("📷 Upload Image")
        uploaded_file = st.file_uploader(
            "Choose an image",
            type=["png", "jpg", "jpeg", "webp", "bmp"],
            label_visibility="collapsed"
        )
        
        image = None
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Original Image", use_container_width=True)
        
        st.subheader("✏️ What would you like to change?")
        instruction = st.text_area(
            "Describe your edit naturally",
            placeholder="Examples:\nβ€’ 'Remove the person in the background'\nβ€’ 'Replace the sky with a sunset'\nβ€’ 'Make the car red'\nβ€’ 'Add a rainbow'\nβ€’ 'Turn the grass into snow'\nβ€’ 'Delete the watermark'",
            label_visibility="collapsed",
            height=120
        )
        
        # Show interpretation preview
        if instruction:
            target_preview, replacement_preview, confidence = gemini_parser.parse(instruction)
            
            confidence_class = "high" if confidence >= 0.8 else "medium" if confidence >= 0.6 else "low"
            confidence_pct = int(confidence * 100)
            
            st.markdown(f"""
            <div class="interpretation-box">
                <strong>🧠 Understanding:</strong> {gemini_parser.interpretation}<br>
                <span class="confidence-{confidence_class}">Confidence: {confidence_pct}%</span>
            </div>
            """, unsafe_allow_html=True)
        
        mask_file = None
        if not auto_mask:
            st.subheader("πŸ“ Manual Mask")
            mask_file = st.file_uploader(
                "Upload a black & white mask (white = area to edit)",
                type=["png", "jpg", "jpeg"],
                key="mask"
            )
        
        edit_clicked = st.button(
            "🎨 Apply Edit",
            type="primary",
            use_container_width=True,
            disabled=(uploaded_file is None or not instruction)
        )
    
    with col2:
        st.subheader("✨ Result")
        result_placeholder = st.empty()
        mask_placeholder = st.empty()
        status_placeholder = st.empty()
        download_placeholder = st.empty()
        
        if edit_clicked and image is not None and instruction:
            try:
                target, replacement_prompt, confidence = gemini_parser.parse(instruction)
                
                status_placeholder.info(f"🎯 **Target:** `{target}`\n\n✨ **Generating:** `{replacement_prompt[:100]}...`")
                
                # Generate mask
                if mask_file is not None:
                    mask_img = Image.open(mask_file)
                    final_mask = process_manual_mask(mask_img, image.size)
                    status_placeholder.info("📝 Using manual mask...")
                elif auto_mask:
                    with st.spinner(f"🔍 AI detecting '{target}'..."):
                        final_mask = generate_mask_clipseg(
                            image=image,
                            target_text=target,
                            threshold=mask_threshold,
                            expand_pixels=mask_expansion
                        )
                    if final_mask is None:
                        st.error("Failed to generate mask")
                        st.stop()
                else:
                    st.error("Please upload a mask or enable auto-detection!")
                    st.stop()
                
                # Check mask has content
                mask_array = np.array(final_mask)
                if mask_array.max() < 128:
                    st.warning(f"⚠️ Could not confidently detect '{target}'. Trying with broader detection...")
                    # Retry with lower threshold
                    final_mask = generate_mask_clipseg(
                        image=image,
                        target_text=target,
                        threshold=mask_threshold * 0.5,
                        expand_pixels=mask_expansion * 2
                    )
                    if final_mask is None or np.array(final_mask).max() < 128:
                        st.error(f"❌ Still could not detect '{target}'. Try different wording or upload a mask.")
                        st.stop()
                
                mask_placeholder.image(final_mask, caption="🎭 Detected Area", use_container_width=True)
                
                # Inpaint
                with st.spinner("🎨 AI is editing your image... This may take a moment."):
                    result = inpaint_image(
                        image=image,
                        mask=final_mask,
                        prompt=replacement_prompt,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale
                    )
                
                if result is not None:
                    result_placeholder.image(result, caption="✅ Edited Image", use_container_width=True)
                    status_placeholder.success("✅ Edit complete!")
                    
                    buf = BytesIO()
                    result.save(buf, format="PNG")
                    download_placeholder.download_button(
                        label="📥 Download Result",
                        data=buf.getvalue(),
                        file_name="edited_image.png",
                        mime="image/png",
                        use_container_width=True
                    )
                else:
                    st.error("Inpainting failed")
                
            except Exception as e:
                # Let Streamlit's st.stop() control-flow exception propagate
                # instead of being swallowed by this generic handler
                if type(e).__name__ == "StopException":
                    raise
                st.error(f"❌ Error: {str(e)}")
        
        elif uploaded_file is None:
            result_placeholder.info("👆 Upload an image to get started")
    
    # Enhanced Examples Section
    st.markdown("---")
    st.subheader("💡 Pro Tips & Examples")
    
    c1, c2, c3, c4 = st.columns(4)
    
    with c1:
        st.markdown("""
        **πŸ—‘οΈ Remove Objects:**
        - `remove the person`
        - `delete the watermark`
        - `erase the car`
        - `get rid of the background`
        - `take out the text`
        """)
    
    with c2:
        st.markdown("""
        **🔄 Replace Objects:**
        - `replace sky with sunset`
        - `swap the car with a bike`
        - `change background to beach`
        - `turn grass into snow`
        """)
    
    with c3:
        st.markdown("""
        **🎨 Change Colors:**
        - `make the car red`
        - `change dress to blue`
        - `turn hair blonde`
        - `paint walls white`
        """)
    
    with c4:
        st.markdown("""
        **✨ Transform Styles:**
        - `make it sunset lighting`
        - `turn into winter scene`
        - `add cyberpunk aesthetic`
        - `make it cinematic`
        """)
    
    st.markdown("---")
    st.markdown(
        """<center>🔒 <b>Privacy First</b> - All processing happens locally. No data sent to external APIs.<br>
        <span style="color: #888;">Powered by Stable Diffusion + CLIPSeg | Created with ❤️</span></center>""",
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    main()