"""Complete Local Flux Studio.

A Gradio front-end for running Flux image-generation checkpoints from local
files (plus optional local CLIP / T5 / VAE weights and LoRA adapters), with an
optional Groq Vision helper that turns an uploaded image into a text prompt.
"""

import base64
import gc
import io
import json
import os
import socket
import sys
import time
import traceback
import warnings

import gradio as gr
import requests
import torch
from diffusers import AutoencoderKL, FluxPipeline
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

# Bytes per GiB, used for the human-readable size reports.
GIB = 1024 * 1024 * 1024
# Every user-facing failure message starts with this marker; callers test for it.
ERROR_MARK = "❌"
OUTPUT_DIR = "outputs/complete_local_flux"


def find_free_port(start_port=7860, attempts=20, host="localhost"):
    """Return the first bindable TCP port in ``[start_port, start_port + attempts)``.

    Args:
        start_port: First port to try.
        attempts: How many consecutive ports to probe.
        host: Interface to probe the ports on.

    Returns:
        The free port number, or None if every port in the range is taken.
    """
    for port in range(start_port, start_port + attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))
            except OSError:
                continue
            return port
    return None


class CompleteLocalFlux:
    """Discover local Flux assets and serve a Gradio studio for generating images.

    Construction only scans the filesystem (and checks ``GROQ_API_KEY``); the
    heavy models are loaded lazily by :meth:`load_flux_complete` on the first
    generation request.
    """

    # (display name, path) for Flux checkpoints, checked in order.
    FLUX_MODEL_CANDIDATES = (
        ("Flux Dev", "./models/Flux/flux-dev.safetensors"),
        ("Flux Schnell", "./models/Flux/flux1-schnell.safetensors"),
        ("Flux Kontex", "./models/Flux/flux-kontex.safetensors"),
        ("Flux Dev Alt", "./flux-dev.safetensors"),
        ("Flux Schnell Alt", "./flux1-schnell.safetensors"),
        ("Flux Kontex Alt", "./flux-kontex.safetensors"),
    )
    # Local T5 encoder weights; the first existing file wins.
    T5_CANDIDATES = (
        "./models/Flux/google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
        "./google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
        "./models/google_t5-v1_1-xxl_encoderonly-fp8_e4m3fn.safetensors",
    )
    # Local CLIP folders (must contain config.json); the first match wins.
    CLIP_CANDIDATES = (
        "./models/clip",
        "./models/CLIP/clip-vit-large-patch14",
        "./clip-vit-large-patch14",
    )
    # VAE folders (diffusers layout) or single weight files; the first match wins.
    VAE_CANDIDATES = (
        "./models/Flux/vae_local",
        "./models/Flux/ae.safetensors",
        "./ae.safetensors",
        "./models/ae.safetensors",
        "./models/Flux/vae.safetensors",
        "./vae.safetensors",
    )
    # Where a downloaded VAE is saved so later runs can stay offline.
    CACHED_VAE_PATH = "./models/Flux/vae_cache"
    HF_CACHE_DIR = "./models/Flux/hf_cache"
    # Directories scanned for LoRA .safetensors files.
    LORA_SEARCH_DIRS = ("./models/lora", ".", "./models", "./lora", "./LoRA")
    # LoRA files always offered when present, even if a scan missed them.
    SPECIFIC_LORA_FILES = (
        "./models/lora/act_person_trained.safetensors",
        "./models/lora/oddtoperson.safetensors",
        "./models/lora/oddtopersonmark2.safetensors",
    )
    CLIP_REPO = "openai/clip-vit-large-patch14"
    T5_REPO = "google/t5-v1_1-xxl"
    FLUX_VAE_REPO = "black-forest-labs/FLUX.1-dev"

    def __init__(self):
        """Check the Groq key, pick a device and discover all local model files."""
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            print("⚠️ GROQ_API_KEY not found in environment variables")
            print(" Set it with: export GROQ_API_KEY='your_api_key_here'")
        else:
            print("✅ Groq API key found")

        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("🚀 Using Apple M2 Max with MPS")
        else:
            self.device = torch.device("cpu")

        self.flux_models = self._discover_flux_models()
        self.local_t5_path = self._discover_t5()
        self.local_clip_path = self._discover_clip()
        self.cached_vae_path = self.CACHED_VAE_PATH
        self.local_vae_path = self._discover_vae()
        self.lora_files = self._discover_lora_files()
        # Dropdown label -> LoRA file path; rebuilt by create_interface().
        self.lora_path_mapping = {"None": "None"}

        # Text encoders are created by load_local_text_encoders().
        self.clip_text_encoder = None
        self.clip_tokenizer = None
        self.t5_text_encoder = None
        self.t5_tokenizer = None

        self.pipeline = None
        self.current_model = None
        # Path (or "None") of the LoRA last applied to self.pipeline.
        self.current_lora = None
        self.lora_loaded = False
        self.encoders_loaded = False

        print(f"✅ Found {len(self.lora_files)} LoRA files")
        for path in self.lora_files:
            print(f" - {path}")

    # ------------------------------------------------------------------ discovery

    def _discover_flux_models(self):
        """Return ``{display name: {"path": str, "size": GiB}}`` for each checkpoint found."""
        found = {}
        for name, path in self.FLUX_MODEL_CANDIDATES:
            if os.path.isfile(path):
                size_gb = os.path.getsize(path) / GIB
                found[name] = {"path": path, "size": size_gb}
                print(f"✅ Found {name}: {path} ({size_gb:.1f} GB)")
        return found

    def _discover_t5(self):
        """Return the first existing local T5 weight file, or None."""
        for path in self.T5_CANDIDATES:
            if os.path.isfile(path):
                size_gb = os.path.getsize(path) / GIB
                print(f"✅ Found T5 model: {path} ({size_gb:.1f} GB)")
                return path
        return None

    def _discover_clip(self):
        """Return the first local CLIP folder that has a config.json, or None."""
        for path in self.CLIP_CANDIDATES:
            if os.path.isfile(os.path.join(path, "config.json")):
                print(f"✅ Found local CLIP model: {path}")
                return path
        print("⚠️ No local CLIP model found - will download on first use")
        return None

    def _discover_vae(self):
        """Return the first usable local VAE (folder with config.json, or weight file), or None."""
        for path in (*self.VAE_CANDIDATES, self.cached_vae_path):
            if os.path.isdir(path):
                # from_pretrained() needs config.json; skip empty or half-written caches.
                if os.path.isfile(os.path.join(path, "config.json")):
                    print(f"✅ Found cached VAE: {path}")
                    return path
            elif os.path.isfile(path):
                size_gb = os.path.getsize(path) / GIB
                print(f"✅ Found VAE model: {path} ({size_gb:.1f} GB)")
                return path
        return None

    def _discover_lora_files(self):
        """Collect LoRA ``.safetensors`` paths, excluding known checkpoint/encoder/VAE files.

        The scanned directories (``.``, ``./models``) also hold base checkpoints,
        which must not be offered as LoRAs. Duplicates are removed by resolved
        path, so ``./lora`` and ``./LoRA`` on a case-insensitive filesystem
        count once. The first-seen spelling of each path is kept.
        """
        non_lora_paths = (
            *(path for _, path in self.FLUX_MODEL_CANDIDATES),
            *self.T5_CANDIDATES,
            *self.VAE_CANDIDATES,
        )
        seen = {os.path.realpath(path) for path in non_lora_paths}

        candidates = []
        for directory in self.LORA_SEARCH_DIRS:
            if not os.path.isdir(directory):
                continue
            try:
                names = sorted(os.listdir(directory))
            except OSError:
                continue
            for name in names:
                full_path = os.path.join(directory, name)
                if name.endswith(".safetensors") and os.path.isfile(full_path):
                    # Files in the working directory keep their bare name.
                    candidates.append(name if directory == "." else full_path)
        candidates.extend(path for path in self.SPECIFIC_LORA_FILES if os.path.isfile(path))

        lora_files = []
        for path in candidates:
            key = os.path.realpath(path)
            if key not in seen:
                seen.add(key)
                lora_files.append(path)
        return lora_files

    # ------------------------------------------------------------------ memory

    def cleanup_memory(self):
        """Drop the current pipeline and release cached CUDA/MPS memory."""
        if getattr(self, "pipeline", None) is not None:
            del self.pipeline
        self.pipeline = None
        # Without a pipeline there is no loaded model/LoRA; force a reload next time.
        self.current_model = None
        self.current_lora = None
        self.lora_loaded = False
        # Collect first so the freed tensors are actually returned by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif self.device.type == "mps":
            torch.mps.empty_cache()
        print("🧹 Memory cleaned up")

    # ------------------------------------------------------------------ text encoders

    def load_local_text_encoders(self):
        """Load the CLIP and T5 text encoders (local weights when available) onto ``self.device``.

        Returns:
            True on success; False on failure (the traceback is printed).
        """
        try:
            print("🔄 Loading text encoders...")
            # float32 throughout: it is the most reliable dtype on MPS.
            dtype = torch.float32
            self._load_clip(dtype)
            self._load_t5(dtype)

            self.clip_text_encoder = self.clip_text_encoder.to(self.device)
            self.t5_text_encoder = self.t5_text_encoder.to(self.device)
            self.encoders_loaded = True
            print("✅ All text encoders loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Error loading text encoders: {e}")
            traceback.print_exc()
            return False

    def _load_clip(self, dtype):
        """Set ``clip_text_encoder``/``clip_tokenizer``: local folder first, Hub download as fallback."""
        if self.local_clip_path:
            print(f" Loading CLIP from local folder: {self.local_clip_path}")
            try:
                self.clip_text_encoder = CLIPTextModel.from_pretrained(
                    self.local_clip_path, torch_dtype=dtype, local_files_only=True
                )
                self.clip_tokenizer = CLIPTokenizer.from_pretrained(
                    self.local_clip_path, local_files_only=True
                )
                print("✅ Local CLIP model loaded successfully!")
                return
            except Exception as e:
                print(f"❌ Error loading local CLIP folder: {e}")
                print(" Falling back to download...")
        else:
            print(" Loading CLIP text encoder (downloading ~1GB)...")
        self.clip_text_encoder = CLIPTextModel.from_pretrained(self.CLIP_REPO, torch_dtype=dtype)
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(self.CLIP_REPO)

    def _load_t5(self, dtype):
        """Set ``t5_text_encoder``/``t5_tokenizer``, overlaying local T5 weights when present."""
        if self.local_t5_path:
            print(f" Loading T5 from local file: {self.local_t5_path}")
            print(" Loading T5 tokenizer...")
        else:
            print(" No local T5 found, downloading...")
        # legacy=False opts into the current tokenizer behaviour and silences its warning.
        self.t5_tokenizer = T5Tokenizer.from_pretrained(self.T5_REPO, legacy=False)

        if self.local_t5_path:
            print(" Loading local T5 weights...")
        # The Hub model supplies the architecture and the fallback weights.
        self.t5_text_encoder = T5EncoderModel.from_pretrained(self.T5_REPO, torch_dtype=dtype)
        if self.local_t5_path:
            self._apply_local_t5_weights()

    def _apply_local_t5_weights(self):
        """Best effort: copy matching tensors from the local T5 file into ``t5_text_encoder``.

        Tensors whose name or shape does not match the model are skipped. On any
        error the already-downloaded weights are kept.
        """
        try:
            print(" Attempting to load local T5 safetensors...")
            local_weights = load_file(self.local_t5_path)
            model_state = self.t5_text_encoder.state_dict()

            compatible = {}
            for key, value in local_weights.items():
                if key not in model_state:
                    print(f"⚠️ Skipping {key}: not found in model")
                elif model_state[key].shape != value.shape:
                    print(f"⚠️ Skipping {key}: shape mismatch {model_state[key].shape} vs {value.shape}")
                else:
                    # load_state_dict copies into the model's own dtype.
                    compatible[key] = value

            if not compatible:
                print("⚠️ No local T5 tensors matched the model - keeping downloaded weights")
                return

            missing_keys, unexpected_keys = self.t5_text_encoder.load_state_dict(compatible, strict=False)
            if missing_keys:
                print(f"⚠️ Missing keys: {len(missing_keys)} (this is often normal)")
            if unexpected_keys:
                print(f"⚠️ Unexpected keys: {len(unexpected_keys)}")
            print(f"✅ Local T5 weights loaded successfully! ({len(compatible)} tensors)")
        except Exception as e:
            print(f"❌ Error loading local T5 weights: {e}")
            print(" Your T5 file may be corrupted or incomplete.")
            print(" Falling back to downloaded weights (model architecture already loaded)...")

    # ------------------------------------------------------------------ VAE / pipeline

    def _load_vae(self):
        """Return a float32 Flux ``AutoencoderKL``: local copy first, otherwise download and cache.

        Raises:
            Exception: if neither a local VAE nor a download could be loaded.
        """
        vae = None
        if self.local_vae_path:
            print(f" Using local VAE from: {self.local_vae_path}")
            try:
                if os.path.isdir(self.local_vae_path):
                    vae = AutoencoderKL.from_pretrained(
                        self.local_vae_path, torch_dtype=torch.float32, local_files_only=True
                    )
                else:
                    # Single weight file: take the architecture from the Hub, then overlay.
                    vae = AutoencoderKL.from_pretrained(
                        self.FLUX_VAE_REPO, subfolder="vae", torch_dtype=torch.float32
                    )
                    if self.local_vae_path.endswith(".safetensors"):
                        missing, unexpected = vae.load_state_dict(
                            load_file(self.local_vae_path), strict=False
                        )
                        if missing or unexpected:
                            # strict=False silently ignores keys in a different naming scheme.
                            print(
                                f"⚠️ Local VAE key mismatch: {len(missing)} missing, "
                                f"{len(unexpected)} unexpected (non-matching weights were not applied)"
                            )
                vae = vae.to(torch.float32)
                print("✅ Local VAE loaded successfully!")
            except Exception as e:
                print(f"❌ Local VAE failed: {e}")
                print(" Falling back to download...")
                vae = None

        if vae is not None:
            return vae

        print(" ⚠️ No local VAE found - downloading from HuggingFace...")
        print(" Consider running download_vae.py for 100% local operation")
        try:
            vae = AutoencoderKL.from_pretrained(
                self.FLUX_VAE_REPO,
                subfolder="vae",
                torch_dtype=torch.float32,
                cache_dir=self.HF_CACHE_DIR,
            )
            vae = vae.to(torch.float32)
        except Exception as e:
            print(f"❌ Failed to download VAE: {e}")
            raise

        # A failed cache write must not throw away a successfully downloaded VAE.
        try:
            os.makedirs(os.path.dirname(self.cached_vae_path), exist_ok=True)
            print(f" Caching VAE to: {self.cached_vae_path}")
            vae.save_pretrained(self.cached_vae_path)
            self.local_vae_path = self.cached_vae_path
            print("✅ VAE downloaded and cached locally!")
        except Exception as e:
            print(f"⚠️ VAE downloaded but could not be cached: {e}")
        return vae

    def _apply_lora(self, lora_choice):
        """Make *lora_choice* (a file path, or "None") the pipeline's only LoRA.

        Any previously loaded LoRA is unloaded first. If the new one fails to
        load, the pipeline is left without a LoRA. ``current_lora`` records the
        request either way, so the same failing file is not retried on every call.
        """
        if self.lora_loaded:
            self.pipeline.unload_lora_weights()
            self.lora_loaded = False
        self.current_lora = lora_choice
        if lora_choice == "None" or lora_choice not in self.lora_files:
            return
        try:
            print(f"🔄 Loading LoRA: {lora_choice}")
            directory, weight_name = os.path.split(lora_choice)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="No LoRA keys associated to CLIPTextModel found")
                warnings.filterwarnings("ignore", message="You can also try specifying")
                self.pipeline.load_lora_weights(directory or ".", weight_name=weight_name)
            self.lora_loaded = True
            print("✅ LoRA loaded successfully!")
        except Exception as e:
            print(f"❌ LoRA loading failed: {e}")
            self.lora_loaded = False

    def load_flux_complete(self, model_choice, lora_choice):
        """Build a FluxPipeline for *model_choice* with an optional LoRA on ``self.device``.

        Args:
            model_choice: A key of ``self.flux_models``.
            lora_choice: A LoRA file path from ``self.lora_files``, or "None".

        Returns:
            A status string. Failures start with "❌".
        """
        try:
            # Validate before touching the current pipeline, so a bad choice cannot discard it.
            if model_choice not in self.flux_models:
                return f"❌ Model {model_choice} not found"

            # Free any previous pipeline first; two Flux transformers rarely fit in memory.
            if self.pipeline is not None:
                print("🧹 Cleaning up previous model...")
                self.cleanup_memory()

            if not self.encoders_loaded and not self.load_local_text_encoders():
                return "❌ Failed to load text encoders"

            model_path = self.flux_models[model_choice]["path"]
            print(f"🔄 Loading {model_choice} with complete setup...")

            print(" Loading VAE...")
            try:
                vae = self._load_vae()
            except Exception as e:
                return f"❌ Could not load VAE: {e}"

            self.pipeline = FluxPipeline.from_single_file(
                model_path,
                text_encoder=self.clip_text_encoder,
                text_encoder_2=self.t5_text_encoder,
                tokenizer=self.clip_tokenizer,
                tokenizer_2=self.t5_tokenizer,
                vae=vae,
                torch_dtype=torch.float32,  # float32 for MPS compatibility
            )
            self.current_model = model_choice
            print(f"✅ {model_choice} loaded completely!")

            # A freshly built pipeline carries no LoRA yet.
            self.lora_loaded = False
            self.current_lora = None
            self._apply_lora(lora_choice)

            self.pipeline = self.pipeline.to(self.device)
            if self.device.type == "mps":
                print(" Converting all components to float32 for MPS...")
                self.pipeline.vae = self.pipeline.vae.to(torch.float32)
                self.pipeline.text_encoder = self.pipeline.text_encoder.to(torch.float32)
                self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.to(torch.float32)
                self.pipeline.enable_attention_slicing()
                print("✅ Enabled MPS optimizations and float32 conversion")

            status = f"✅ {model_choice} ready"
            if self.local_t5_path:
                status += " (local T5)"
            if self.local_clip_path:
                status += " (local CLIP)"
            if self.local_vae_path:
                status += " (local VAE)"
            if self.lora_loaded:
                status += f" + LoRA ({lora_choice})"
            return status
        except Exception as e:
            print(f"❌ Error in load_flux_complete: {e}")
            traceback.print_exc()
            return f"❌ Error: {e}"

    # ------------------------------------------------------------------ generation

    def generate_image(self, prompt, model_choice, lora_choice, steps, guidance, seed,
                       width=1024, height=1024):
        """Generate one image with exactly the given settings and save it under ``OUTPUT_DIR``.

        Args:
            prompt: Text prompt; must be non-blank.
            model_choice: A key of ``self.flux_models``.
            lora_choice: Dropdown label (mapped via ``lora_path_mapping``), a file path, or "None".
            steps: Number of inference steps.
            guidance: Guidance scale.
            seed: RNG seed; converted to int.
            width, height: Output size in pixels.

        Returns:
            ``(PIL image or None, status string)``.
        """
        # Cheap validation first: loading a model can take minutes.
        if not prompt or not prompt.strip():
            return None, "❌ Please enter a prompt"

        # The dropdown shows bare filenames; map back to the real path.
        actual_lora_choice = self.lora_path_mapping.get(lora_choice, lora_choice)

        if self.pipeline is None or self.current_model != model_choice:
            print(f"🔄 Need to load model: {model_choice}")
            load_status = self.load_flux_complete(model_choice, actual_lora_choice)
            if ERROR_MARK in load_status:
                print(f"❌ Model loading failed: {load_status}")
                return None, load_status
        elif actual_lora_choice != self.current_lora:
            # Same model, different LoRA: swap adapters instead of reloading everything.
            try:
                self._apply_lora(actual_lora_choice)
            except Exception as e:
                print(f"❌ LoRA switch failed: {e}")
                return None, f"❌ Could not switch LoRA: {e}"

        try:
            seed = int(seed)
            print("🎨 Starting generation...")
            print(f" Prompt: {prompt[:60]}...")
            print(f" Model: {model_choice}")
            print(f" LoRA: {lora_choice}")
            print(f" Steps: {steps}, Guidance: {guidance}, Seed: {seed}")
            torch.manual_seed(seed)

            print(f" Using your exact settings: {steps} steps, guidance: {guidance}")
            print("🔄 Running pipeline...")
            with torch.inference_mode():
                result = self.pipeline(
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance,
                    width=int(width),
                    height=int(height),
                    generator=torch.Generator(device=self.device).manual_seed(seed),
                )

            images = getattr(result, "images", None)
            if not images:
                print("❌ No images in pipeline result")
                return None, "❌ Pipeline returned no images"
            image = images[0]
            print("✅ Image generated successfully!")

            if self.device.type == "mps":
                torch.mps.empty_cache()

            filepath = self._save_output(image, model_choice, seed)

            status = f"✅ Generated with {model_choice}"
            if self.lora_loaded:
                status += " + LoRA"
            if self.local_t5_path:
                status += " (local T5)"
            status += f"\n📐 {int(width)}x{int(height)} • {steps} steps • Guidance: {guidance} • Seed: {seed}"
            status += f"\n💾 {filepath}"
            print("🎉 Generation complete!")
            return image, status
        except Exception as e:
            error_msg = f"❌ Generation failed: {str(e)}"
            print(error_msg)
            traceback.print_exc()
            return None, error_msg

    def _save_output(self, image, model_choice, seed):
        """Save *image* as PNG under ``OUTPUT_DIR`` and return its path.

        The name encodes model, active LoRA, seed and a timestamp, so repeated
        runs with the same seed do not overwrite each other.
        """
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        model_name = model_choice.replace(" ", "_").lower()
        if self.lora_loaded and self.current_lora:
            lora_name = os.path.basename(self.current_lora).replace(".safetensors", "")
            lora_name = lora_name.replace("/", "_").replace("\\", "_").replace(" ", "_")
        else:
            lora_name = "no_lora"
        stamp = time.strftime("%Y%m%d-%H%M%S")
        filepath = os.path.join(OUTPUT_DIR, f"{model_name}_{lora_name}_{seed}_{stamp}.png")
        print(f"💾 Saving to: {filepath}")
        image.save(filepath, optimize=True)
        return filepath

    # ------------------------------------------------------------------ Groq vision

    def image_to_base64(self, image):
        """Return *image* as a base64-encoded JPEG (at most 1024 px per side), or None on error.

        The caller's image is never modified.
        """
        try:
            # thumbnail() resizes in place; work on a copy so the UI image is untouched.
            image = image.copy()
            max_size = 1024  # Groq limits request size
            if image.width > max_size or image.height > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            if image.mode != "RGB":
                image = image.convert("RGB")

            buffered = io.BytesIO()
            image.save(buffered, format="JPEG", quality=85)
            return base64.b64encode(buffered.getvalue()).decode()
        except Exception as e:
            print(f"❌ Error converting image to base64: {e}")
            return None

    def analyze_image_with_groq(self, image):
        """Describe *image* as a generation prompt via the Groq Vision API.

        Returns:
            The description, or an error message starting with "❌".
        """
        if not self.groq_api_key:
            return "❌ Groq API key not configured. Set GROQ_API_KEY environment variable."
        if image is None:
            return "❌ Please upload an image to analyze"

        try:
            print("🔍 Analyzing image with Groq Vision...")
            base64_image = self.image_to_base64(image)
            if not base64_image:
                return "❌ Failed to convert image to base64"

            headers = {
                "Authorization": f"Bearer {self.groq_api_key}",
                "Content-Type": "application/json",
            }
            payload = {
                "model": "meta-llama/llama-4-scout-17b-16e-instruct",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": "Describe this image in detail for an AI image generation prompt. Focus on visual elements, style, composition, lighting, colors, mood, and artistic techniques. Be descriptive but concise. Format it as a prompt that could be used to recreate a similar image.",
                            },
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                            },
                        ],
                    }
                ],
                "max_tokens": 300,
                "temperature": 0.3,
            }

            response = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=30,
            )
            if response.status_code != 200:
                error_msg = f"Groq API error: {response.status_code} - {response.text}"
                print(f"❌ {error_msg}")
                return f"❌ {error_msg}"

            description = response.json()["choices"][0]["message"]["content"].strip()
            print("✅ Image analysis complete!")
            return description
        except Exception as e:
            error_msg = f"Error analyzing image: {str(e)}"
            print(f"❌ {error_msg}")
            return f"❌ {error_msg}"

    # ------------------------------------------------------------------ UI

    def _build_lora_choices(self):
        """Return dropdown labels (bare filenames) and rebuild ``lora_path_mapping``.

        If two LoRAs share a filename, the later one is labelled with its full
        path so both remain selectable.
        """
        choices = ["None"]
        self.lora_path_mapping = {"None": "None"}
        for lora_path in self.lora_files:
            label = os.path.basename(lora_path)
            if label in self.lora_path_mapping:
                label = lora_path
            choices.append(label)
            self.lora_path_mapping[label] = lora_path
        return choices

    def _render_local_setup(self):
        """Add Markdown lines summarising the discovered local assets (call inside a Blocks context)."""
        gr.Markdown("## 📁 Your Local Setup:")
        for name, info in self.flux_models.items():
            gr.Markdown(f"- **{name}**: {info['size']:.1f} GB")
        if self.local_t5_path:
            t5_size = os.path.getsize(self.local_t5_path) / GIB
            gr.Markdown(f"- **T5 Encoder**: {t5_size:.1f} GB (local)")
        if self.local_clip_path:
            clip_file = os.path.join(self.local_clip_path, "model.safetensors")
            if os.path.exists(clip_file):
                clip_size = os.path.getsize(clip_file) / GIB
                gr.Markdown(f"- **CLIP Encoder**: {clip_size:.1f} GB (local)")
            else:
                gr.Markdown("- **CLIP Encoder**: local folder found")
        if self.local_vae_path:
            if os.path.isdir(self.local_vae_path):
                gr.Markdown("- **VAE**: cached (local)")
            else:
                vae_size = os.path.getsize(self.local_vae_path) / GIB
                gr.Markdown(f"- **VAE**: {vae_size:.1f} GB (local)")
        gr.Markdown(f"- **LoRA Models**: {len(self.lora_files)} found")

    def create_interface(self):
        """Build and return the Gradio Blocks UI (without launching it)."""
        model_choices = list(self.flux_models) or ["No models found"]
        clean_lora_choices = self._build_lora_choices()

        with gr.Blocks(title="Complete Local Flux Studio", theme=gr.themes.Soft()) as interface:
            gr.Markdown("# 🏠 Complete Local Flux Studio")
            gr.Markdown("*Using your local Flux models + T5 + LoRA - maximum efficiency!*")
            self._render_local_setup()

            gr.Markdown("## 🔍 Image Analysis with Groq Vision")
            gr.Markdown("*Upload an image to automatically generate a prompt description*")
            input_image = gr.Image(label="📤 Upload Image to Analyze", type="pil", height=200)
            analyze_btn = gr.Button("🔍 Analyze Image with Groq Vision", variant="primary", size="lg")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## 🎨 Generate")
                    model_choice = gr.Dropdown(
                        choices=model_choices,
                        value=model_choices[0] if model_choices[0] != "No models found" else None,
                        label="Flux Model",
                    )
                    lora_choice = gr.Dropdown(
                        choices=clean_lora_choices,
                        value=clean_lora_choices[1] if len(clean_lora_choices) > 1 else "None",
                        label="Your LoRA",
                    )
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="artistic lifestyle portrait, person wearing vibrant orange bucket hat, expressive face, golden hour lighting, street style photography, film aesthetic",
                        lines=6,
                        placeholder="Enter your prompt here, or upload an image above and click 'Analyze' to auto-generate...",
                    )
                    with gr.Row():
                        steps = gr.Slider(4, 50, value=20, label="Steps")
                        guidance = gr.Slider(0.0, 10.0, value=3.5, label="Guidance")
                        seed = gr.Number(value=42, label="Seed")
                    generate_btn = gr.Button("🏠 Generate Locally", variant="primary", size="lg")

                with gr.Column(scale=1):
                    output_image = gr.Image(label="Generated Image", height=600)
                    status = gr.Textbox(label="Status", interactive=False, lines=4)

            gr.Markdown("## 🎨 Your Artistic Style")
            with gr.Row():
                portrait_btn = gr.Button("🎭 Portrait")
                vibrant_btn = gr.Button("🌈 Vibrant")
                street_btn = gr.Button("📸 Street")

            def analyze_into_prompt(image, current_prompt):
                """Replace the prompt with Groq's description; on error keep it and report in Status."""
                description = self.analyze_image_with_groq(image)
                if description.startswith(ERROR_MARK):
                    return current_prompt, description
                return description, "✅ Prompt generated from image"

            analyze_btn.click(
                fn=analyze_into_prompt,
                inputs=[input_image, prompt],
                outputs=[prompt, status],
            )
            portrait_btn.click(
                lambda: "artistic lifestyle portrait, person with expressive face, vibrant clothing, golden hour lighting",
                outputs=[prompt],
            )
            vibrant_btn.click(
                lambda: "person in colorful streetwear, vibrant orange bucket hat, street photography, film aesthetic",
                outputs=[prompt],
            )
            street_btn.click(
                lambda: "urban street style portrait, candid expression, natural lighting, contemporary photography",
                outputs=[prompt],
            )
            generate_btn.click(
                fn=self.generate_image,
                inputs=[prompt, model_choice, lora_choice, steps, guidance, seed],
                outputs=[output_image, status],
            )

        return interface

    def launch(self, share=True):
        """Build the UI and serve it on the first free port from 7860.

        Args:
            share: Passed to Gradio. True creates a public share link that
                exposes this machine's models (and Groq key usage) to anyone
                holding the URL.
        """
        interface = self.create_interface()
        port = find_free_port()
        print("🏠 Launching Complete Local Flux Studio...")
        if port is None:
            # server_port=None lets Gradio choose a port itself.
            print("⚠️ No free port found in 7860-7879 - letting Gradio choose one")
        else:
            print(f"📱 Interface: http://localhost:{port}")
        print("🚀 Using maximum local resources!")
        try:
            interface.launch(server_port=port, share=share, inbrowser=True)
        except Exception as e:
            print(f"❌ Launch failed: {e}")


if __name__ == "__main__":
    # T5Tokenizer requires sentencepiece; fail early with an install hint.
    try:
        import sentencepiece  # noqa: F401
        print("✅ SentencePiece found")
    except ImportError:
        print("❌ SentencePiece not found")
        print("🔧 Install with: pip install sentencepiece protobuf")
        sys.exit(1)

    studio = CompleteLocalFlux()
    studio.launch()