| import gradio as gr |
| import torch |
| from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, LCMScheduler |
| import time |
| import json |
| import hashlib |
| import base64 |
| import os |
| import io |
| import traceback |
| import threading |
| from datetime import datetime, timedelta |
| from PIL import Image |
|
|
# --- Runtime configuration -------------------------------------------------

# Hugging Face model identifier passed to `from_pretrained`.
MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"

# Compute device name for the pipelines.
DEVICE = "cpu"

# Directory holding cached PNG results, one file per cache key.
CACHE_DIR = "/tmp/generated_images"

# Cached files older than this many seconds are discarded (24 hours).
CACHE_TTL_SECONDS = 24 * 60 * 60

# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a per-request limit in seconds -- confirm whether it is enforced.
GENERATION_TIMEOUT = 120

# Make sure the cache directory exists before any request tries to use it.
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
print("Loading Advanced Quality Stack (V2) - MicroCore Studio Edition...")

# Text-to-image pipeline, loaded in full precision with the safety checker
# explicitly disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
)

# Swap in the LCM scheduler, reusing the configuration of the loaded one.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Honour the DEVICE setting; previously the constant was defined but never
# applied, so changing it had no effect.
pipe.to(DEVICE)

pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

# The image-to-image pipeline is built from the same component objects as
# `pipe`, so no second copy of the model is loaded.  `requires_safety_checker`
# is not part of `components`, so it is passed explicitly to stay consistent
# with the `safety_checker=None` configuration used above.
pipe_i2i = StableDiffusionImg2ImgPipeline(
    **pipe.components,
    requires_safety_checker=False,
)

pipe.enable_attention_slicing(1)
pipe_i2i.enable_attention_slicing(1)

print("Models loaded and FreeU enabled.")
|
|
# Prompt templates keyed by the style name shown in the UI.  Each template
# contains a single `{prompt}` placeholder that is filled with the user's text
# via `str.format(prompt=...)`.
STYLES = {
    # Pass the prompt through untouched.
    "None": "{prompt}",
    "Cinematic": "cinematic photo, {prompt}, highly detailed, 8k, sharp focus, dramatic lighting, film grain",
    "Photorealistic": "professional photography, {prompt}, realistic, 35mm lens, f/1.8, bokeh, masterpiece, ultra-detailed",
    "Digital Art": "digital painting, {prompt}, intricate detail, vibrant colors, fantasy art, smooth, sharp edges",
    "Cyberpunk": "cyberpunk style, {prompt}, neon lights, rainy street, high contrast, futuristic, blade runner aesthetic",
    "Anime": "anime style, {prompt}, hand-drawn, high resolution, vibrant, clean lines",
}
|
|
class CircuitBreaker:
    """Failure counter that blocks callers after repeated errors.

    The breaker moves between three string states held in ``self.state``:
    ``"CLOSED"`` (normal), ``"OPEN"`` (callers are told to back off) and
    ``"HALF_OPEN"`` (the pause has elapsed and calls are let through again).
    All state changes happen while holding ``self.lock``.
    """

    def __init__(self, failure_threshold=3, reset_timeout=60):
        """Create a closed breaker.

        Args:
            failure_threshold: Failure count at which the breaker opens.
            reset_timeout: Seconds after the last failure before an open
                breaker lets calls through again.
        """
        self.lock = threading.Lock()
        self.state = "CLOSED"
        self.failure_count = 0
        self.last_failure_time = None
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout

    def is_open(self):
        """Return True while calls should be rejected.

        An open breaker whose pause has elapsed is switched to
        ``"HALF_OPEN"`` as a side effect, and False is returned.
        """
        with self.lock:
            if self.state != "OPEN":
                return False

            elapsed = time.time() - self.last_failure_time
            if elapsed < self.reset_timeout:
                return True

            self.state = "HALF_OPEN"
            print("[CircuitBreaker] Transitioning to HALF_OPEN")
            return False

    def record_success(self):
        """Reset the failure counter and close the breaker."""
        with self.lock:
            self.state = "CLOSED"
            self.failure_count = 0

    def record_failure(self):
        """Count one failure and open the breaker once the threshold is hit."""
        with self.lock:
            self.last_failure_time = time.time()
            self.failure_count += 1
            if self.failure_count < self.failure_threshold:
                return
            self.state = "OPEN"
            print(f"[CircuitBreaker] OPEN after {self.failure_count} failures, pausing for {self.reset_timeout}s")
|
|
# Single breaker instance shared by every generation request in this process:
# it opens after 3 recorded failures and pauses requests for 60 seconds.
circuit_breaker = CircuitBreaker(
    failure_threshold=3,
    reset_timeout=60,
)
|
|
def compute_cache_key(payload_dict):
    """Return a SHA-256 hex digest identifying *payload_dict*.

    The payload is serialised to JSON with sorted keys, so two dicts with the
    same items produce the same key regardless of insertion order.
    """
    canonical_json = json.dumps(payload_dict, sort_keys=True)
    digest = hashlib.sha256(canonical_json.encode())
    return digest.hexdigest()
|
|
def get_cached_image(cache_key):
    """Return the cached PNG bytes for *cache_key*, or None on a miss.

    A miss is reported when the file does not exist, cannot be read, or is
    older than ``CACHE_TTL_SECONDS``; an expired file is deleted on the way
    out (best effort).

    Args:
        cache_key: Key as produced by ``compute_cache_key``; the entry lives
            at ``<CACHE_DIR>/<cache_key>.png``.

    Returns:
        The raw file contents as ``bytes``, or ``None``.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.png")

    # Ask for the mtime directly instead of checking existence first: the
    # file can be removed by another request between the two calls, and the
    # epoch-based age is immune to local-time (DST) jumps.
    try:
        age_seconds = time.time() - os.path.getmtime(cache_path)
    except OSError:
        return None

    if age_seconds > CACHE_TTL_SECONDS:
        try:
            os.remove(cache_path)
        except OSError:
            # Deliberate best effort: someone else may already have removed
            # the expired entry, and a leftover file is simply retried later.
            pass
        return None

    try:
        with open(cache_path, "rb") as f:
            return f.read()
    except OSError:
        return None
|
|
def save_cached_image(cache_key, image_data):
    """Store *image_data* as the cache entry for *cache_key*.

    The bytes are written to a temporary sibling file and then moved into
    place with ``os.replace``, so a reader never sees a half-written PNG and a
    failed write never leaves a truncated entry behind.  Failures are logged
    and swallowed: caching is best effort and must not fail the request.

    Args:
        cache_key: Key as produced by ``compute_cache_key``.
        image_data: PNG file contents as ``bytes``.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.png")
    # Unique per process and thread so concurrent writers never share a
    # temporary file; the ".tmp" suffix keeps it from matching "<key>.png".
    tmp_path = f"{cache_path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(image_data)
        os.replace(tmp_path, cache_path)
    except OSError as e:
        print(f"[Cache] Warning: Could not save to {cache_path}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (the temporary file was never created).
            pass
|
|
def image_to_base64(image):
    """Serialise *image* as PNG and return it as a base64 text string.

    *image* only needs a ``save(file_obj, format=...)`` method, as PIL images
    provide.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
|
def generate_advanced(prompt, negative_prompt, style, steps, guidance, polish_intensity):
    """Generate an image for *prompt*, using the on-disk cache when possible.

    The request is rejected up front while the circuit breaker is open.
    Otherwise the inputs are validated, the cache is consulted, and on a miss
    the image is produced in two stages: a 768x768 base render with ``pipe``
    followed by an optional img2img polish pass with ``pipe_i2i``.  A failed
    polish pass is not fatal; the base image is returned instead.

    Args:
        prompt: Non-empty description of the image.
        negative_prompt: Things to avoid; when empty, a built-in fallback is
            used for both stages.
        style: Key of ``STYLES`` selecting the prompt template.
        steps: Inference step count (``None`` means 8); must be at least 1.
        guidance: Guidance scale (``None`` means 1.5).
        polish_intensity: Strength of the polish pass (``None`` means 0.3);
            values <= 0 skip the pass.

    Returns:
        A ``(image, status)`` tuple: the PIL image and a human-readable
        status line that includes the first 16 characters of the cache key.

    Raises:
        gr.Error: With a JSON-encoded message when the breaker is open, the
            payload is invalid (code 400), or stage 1 fails (code 500).
    """
    if circuit_breaker.is_open():
        # Report the breaker's real pause rather than a hard-coded number.
        error_msg = json.dumps({"error": f"Service temporarily unavailable due to high error rate. Please retry in {circuit_breaker.reset_timeout}s."})
        raise gr.Error(error_msg)

    if not isinstance(prompt, str) or not prompt.strip():
        error_msg = json.dumps({"error": "Invalid payload: 'prompt' is required and must be a non-empty string.", "code": 400})
        raise gr.Error(error_msg)

    try:
        steps = int(steps) if steps is not None else 8
        guidance = float(guidance) if guidance is not None else 1.5
        polish_intensity = float(polish_intensity) if polish_intensity is not None else 0.3
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers int(float("inf")), which is also bad input.
        error_msg = json.dumps({"error": "Invalid payload: steps, guidance, and polish_intensity must be numeric.", "code": 400})
        raise gr.Error(error_msg) from None

    if steps < 1:
        # Reject here so a bad step count is reported as a client error and
        # does not count against the circuit breaker as a pipeline failure.
        error_msg = json.dumps({"error": "Invalid payload: 'steps' must be at least 1.", "code": 400})
        raise gr.Error(error_msg)

    # Validate the style before touching the cache or the model.
    if not isinstance(style, str) or style not in STYLES:
        error_msg = json.dumps({"error": f"Invalid payload: unknown style '{style}'. Valid styles: {list(STYLES.keys())}", "code": 400})
        raise gr.Error(error_msg)

    cache_payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt or "",
        "style": style,
        "steps": steps,
        "guidance": guidance,
        "polish_intensity": polish_intensity,
    }
    cache_key = compute_cache_key(cache_payload)

    cached = get_cached_image(cache_key)
    if cached is not None:
        try:
            cached_image = Image.open(io.BytesIO(cached))
            # Image.open is lazy; force a full decode so a damaged cache
            # entry is detected here instead of while rendering the result.
            cached_image.load()
        except (OSError, SyntaxError, ValueError) as e:
            print(f"[Cache] Corrupt entry for key {cache_key[:16]}..., regenerating: {e}")
        else:
            print(f"[Cache] HIT for key {cache_key[:16]}...")
            return cached_image, f"CACHED (TTL: 24h) | Cache Key: {cache_key[:16]}..."

    start_time = time.time()

    full_prompt = STYLES[style].format(prompt=prompt)
    # Both stages use the same negative prompt, including the fallback.
    effective_negative = negative_prompt if negative_prompt else "blurry, low quality, distorted"

    print(f"Stage 1: Generating base image with {style} style...")
    try:
        base_image = pipe(
            prompt=full_prompt,
            negative_prompt=effective_negative,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=768,
            height=768,
        ).images[0]
    except Exception as e:
        # Boundary handler: any pipeline error counts as a service failure.
        circuit_breaker.record_failure()
        print(f"[Pipeline] Stage 1 failed: {e}")
        error_msg = json.dumps({"error": f"Generation failed in stage 1: {str(e)}", "code": 500})
        raise gr.Error(error_msg) from e

    final_image = base_image
    if polish_intensity > 0:
        print(f"Stage 2: Applying Polish Pass (Intensity: {polish_intensity})...")
        try:
            final_image = pipe_i2i(
                prompt=full_prompt,
                negative_prompt=effective_negative,
                image=base_image,
                strength=polish_intensity,
                num_inference_steps=max(1, int(steps / 2)),
                guidance_scale=guidance,
            ).images[0]
        except Exception as e:
            # The polish pass is optional: fall back to the base render.
            print(f"[Pipeline] Stage 2 (polish) failed, returning base image: {e}")
            final_image = base_image

    duration = round(time.time() - start_time, 2)
    circuit_breaker.record_success()

    try:
        # Serialise straight to PNG bytes; no base64 round trip is needed.
        png_buffer = io.BytesIO()
        final_image.save(png_buffer, format="PNG")
        raw_bytes = png_buffer.getvalue()
        save_cached_image(cache_key, raw_bytes)
        print(f"[Cache] SAVED key {cache_key[:16]}... ({len(raw_bytes)} bytes)")
    except Exception as e:
        # Caching is best effort; never fail a successful generation over it.
        print(f"[Cache] Warning: Failed to save cache: {e}")

    return final_image, f"Quality Optimized Generation in {duration}s | Cache Key: {cache_key[:16]}..."
|
|
# Gradio UI: inputs in the left column, generated image and status on the
# right.  The component variables are module-level names.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Advanced CPU Image Gen V2 — MicroCore Studio Edition")
    gr.Markdown("Using **FreeU** + **Two-Stage Refinement** + **Caching** + **Circuit Breaker** for reliable API integration.")

    with gr.Row():
        # Left column: prompt, style and tuning controls.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe your vision...",
                lines=3,
            )
            style = gr.Dropdown(
                choices=list(STYLES),
                value="Photorealistic",
                label="Style Engine (Auto-Boosting)",
            )

            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, ugly, low quality, deformed",
                )
                steps = gr.Slider(minimum=4, maximum=12, value=8, step=1, label="Steps")
                guidance = gr.Slider(minimum=1.0, maximum=4.0, value=1.5, step=0.1, label="Guidance Scale")
                polish = gr.Slider(minimum=0.0, maximum=0.5, value=0.3, step=0.05, label="Polish Intensity (Refiner Pass)")

            btn = gr.Button("🎨 Generate High Quality Image", variant="primary")

        # Right column: result image and status line.
        with gr.Column():
            output_image = gr.Image(label="V2 Advanced Output", format="png")
            status = gr.Text(label="Engine Status")

    # The input order must match the parameter order of generate_advanced.
    btn.click(
        fn=generate_advanced,
        inputs=[prompt, negative_prompt, style, steps, guidance, polish],
        outputs=[output_image, status],
    )
|
|
# Start the web server only when run as a script, listening on all network
# interfaces at port 7860.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|