# ai-x9mhwmda / utils.py
# Last commit by: BoobyBoobs
# Commit 7d37ffa (verified): "Deploy Gradio app with multiple files"
import random
from typing import List, Optional, Tuple

import numpy as np
import torch
from PIL import Image
def set_seed(seed: int):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator. When CUDA is available, the current CUDA device and all CUDA
    devices are seeded as well.

    Args:
        seed: Integer seed applied to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def optimize_image_for_web(image: Image.Image, max_size: int = 1024, quality: int = 85) -> Image.Image:
    """Downscale and normalize an image for web display.

    If the longer side exceeds ``max_size``, the image is resized with LANCZOS
    resampling so that its longer side becomes ``max_size``, keeping the aspect
    ratio. Images not already in RGB mode are converted to RGB; any
    transparency is dropped by ``Image.convert`` rather than composited.

    Args:
        image: Source PIL image. It is not modified in place; resizing and
            conversion each produce a new image.
        max_size: Maximum allowed length, in pixels, of the longer side.
            Must be at least 1.
        quality: Currently unused. It is kept for backward compatibility, and
            encoder quality only matters when an image is saved, not for the
            in-memory image returned here.

    Returns:
        The (possibly resized and converted) RGB image.

    Raises:
        ValueError: If ``max_size`` is less than 1.
    """
    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    width, height = image.size
    longest = max(width, height)
    if longest > max_size:
        ratio = max_size / longest
        # Use round() instead of int(): float error such as 1023.9999 would
        # otherwise truncate the long side to one pixel short. max(1, ...)
        # stops very thin images from collapsing to a zero-sized dimension,
        # which PIL cannot resize to.
        new_size = (max(1, round(width * ratio)), max(1, round(height * ratio)))
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image
def create_image_grid(images: List[Image.Image], cols: int = 2) -> Optional[Image.Image]:
    """Tile images into a single RGB grid, filling it row by row.

    ``None`` entries are skipped. Every tile takes the size of the first valid
    image; differently sized images are resized (LANCZOS) to match. Images in
    other modes are converted to RGB when ``Image.paste`` places them on the
    RGB canvas.

    Args:
        images: Images to tile; may contain ``None`` entries.
        cols: Maximum number of columns. If there are fewer valid images than
            ``cols``, the grid is narrowed to the number of images so it has no
            empty columns.

    Returns:
        The grid image, or ``None`` if ``images`` is empty/``None`` or holds
        only ``None`` entries.

    Raises:
        ValueError: If there is at least one image and ``cols`` is less than 1.
    """
    # Drop missing entries. An empty or all-None input yields no grid.
    valid_images = [img for img in (images or []) if img is not None]
    if not valid_images:
        return None
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    # Never allocate more columns than there are tiles. Otherwise the canvas
    # (black by default from Image.new) gets trailing empty columns.
    cols = min(cols, len(valid_images))
    rows = (len(valid_images) + cols - 1) // cols  # ceiling division
    width, height = valid_images[0].size

    grid = Image.new('RGB', (cols * width, rows * height))
    for index, img in enumerate(valid_images):
        if img.size != (width, height):
            img = img.resize((width, height), Image.Resampling.LANCZOS)
        row, col = divmod(index, cols)
        grid.paste(img, (col * width, row * height))
    return grid
def validate_prompt(prompt: str) -> Tuple[bool, str]:
    """Check a user prompt and return its cleaned form.

    The prompt is stripped of surrounding whitespace. It is then rejected if it
    contains any blocked keyword, or if it is longer than 500 characters. The
    keyword check is a case-insensitive *substring* match, so e.g.
    "nonviolence" is rejected as well. It runs before the length check.

    Args:
        prompt: Raw prompt text. ``None``, empty, or whitespace-only input is
            rejected.

    Returns:
        ``(True, cleaned_prompt)`` when the prompt is acceptable, otherwise
        ``(False, reason)`` where ``reason`` is a user-facing message.
    """
    cleaned = (prompt or "").strip()
    if not cleaned:
        return False, "Prompt cannot be empty"

    lowered = cleaned.lower()
    blocked_keywords = ('violence', 'gore', 'nsfw', 'adult')
    for keyword in blocked_keywords:
        if keyword in lowered:
            return False, "Prompt may contain inappropriate content"

    if len(cleaned) > 500:
        return False, "Prompt too long (max 500 characters)"
    return True, cleaned
def get_device_info() -> dict:
    """Summarize the compute device PyTorch can use.

    Returns:
        A dict with ``'device'`` (``'cuda'`` or ``'cpu'``) and
        ``'cuda_available'``. When CUDA is available it also contains
        ``'cuda_device_count'``, ``'cuda_device_name'`` (device 0), and
        ``'cuda_memory_allocated'`` / ``'cuda_memory_reserved'`` for device 0,
        expressed as bytes / 1024**3.
    """
    has_cuda = torch.cuda.is_available()
    info = {'device': 'cuda' if has_cuda else 'cpu', 'cuda_available': has_cuda}
    if not has_cuda:
        return info

    bytes_per_gib = 1024 ** 3
    info.update(
        cuda_device_count=torch.cuda.device_count(),
        cuda_device_name=torch.cuda.get_device_name(0),
        cuda_memory_allocated=torch.cuda.memory_allocated(0) / bytes_per_gib,
        cuda_memory_reserved=torch.cuda.memory_reserved(0) / bytes_per_gib,
    )
    return info
def format_generation_info(seed: int, steps: int, guidance: float,
                           width: int, height: int, enhanced: bool = False) -> str:
    """Build the multi-line summary shown after an image is generated.

    Args:
        seed: Seed used for the generation.
        steps: Number of inference steps.
        guidance: Guidance value, rendered with ``str()`` formatting.
        width: Output width in pixels.
        height: Output height in pixels.
        enhanced: If true, a final "Prompt Enhanced" line is appended.

    Returns:
        The summary lines joined with newlines, with no trailing newline.
    """
    # These emoji were previously mis-encoded (UTF-8 bytes decoded as a legacy
    # code page), which showed up as "πŸ..." garbage in the UI.
    lines = [
        "✨ Generation Complete!",
        f"🎲 Seed: {seed}",
        f"🔢 Steps: {steps}",
        f"🎯 Guidance: {guidance}",
        f"📐 Size: {width}x{height}",
    ]
    if enhanced:
        lines.append("✨ Prompt Enhanced")
    return "\n".join(lines)