import gradio as gr import spaces import torch import json from pathlib import Path from PIL import Image import numpy as np def ensure_env_installed(): try: import transformers import torchvision import diffusers import einops except ImportError: import subprocess import sys subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.54.0", "torchvision==0.22.1", "diffusers==0.34.0", "einops==0.8.1"]) ensure_env_installed() # Global model variable model_zoo = { "imuru_small": { "repo_id": "Ruian7P/imuru_small", }, "imuru_large": { "repo_id": "Ruian7P/imuru_large", }, # "emuru_t5_small": { # "repo_id": "Ruian7P/emuru_result", # "model_name": "emuru_t5_small_2e-5_ech5" # } } model = None def load_model(model_name="imuru_large"): global model if model is None: print(f"Loading model {model_name}...") from transformers import AutoModel model = AutoModel.from_pretrained( model_zoo[model_name]["repo_id"], trust_remote_code=True ) model.eval() print("✅ Model loaded") return model def load_examples(): """Load example samples.""" examples = [] examples.append([ "sample/sample.png", "Ruian7P" ]) return examples def process_image(img): from torchvision.transforms import functional as F img = img.convert("RGB") img = img.resize((img.width * 64 // img.height, 64)) img = F.to_tensor(img) img = F.normalize(img, [0.5], [0.5]) return img @spaces.GPU def generate_handwriting(style_image, gen_text, model_name="imuru_large"): """Generate handwriting in the style of the input image.""" if not gen_text or gen_text.strip() == "": return None, "❌ Please provide text to generate" if style_image is None: return None, "❌ Please upload a style image" try: # Convert numpy array to PIL Image if needed if isinstance(style_image, np.ndarray): style_image = Image.fromarray(style_image) # Load and move model to GPU loaded_model = load_model(model_name) loaded_model.to("cuda") # Preprocess style image style_img = process_image(style_image).to("cuda") # Generate with torch.inference_mode(): result = loaded_model.generate( style_img=style_img, gen_text=gen_text, max_new_tokens=512 ) return result, "✅ Generation successful!" except Exception as e: import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" # Custom CSS for better styling custom_css = """ .gradio-container { width: 100%; max-width: 1200px !important; margin: 0 auto !important; } .header-text { text-align: center; margin-bottom: 1rem; } .feature-box { background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); border-radius: 10px; padding: 15px; margin: 10px 0; } footer { visibility: hidden; } """ # Build the interface with gr.Blocks for better customization with gr.Blocks(css=custom_css, title="Imuru") as demo: # Header gr.HTML("""