File size: 5,833 Bytes
59e62e2 dfd0c89 5928093 cd89c5b 5928093 cd89c5b 5928093 cd89c5b 5928093 cd89c5b 59e62e2 dfd0c89 2bdefe0 dfd0c89 bd2e01e 2bdefe0 dfd0c89 59e62e2 dfd0c89 bd2e01e dfd0c89 8a3954f dfd0c89 2bdefe0 dfd0c89 bd2e01e dfd0c89 0e62a79 dfd0c89 54ebe28 dfd0c89 54ebe28 dfd0c89 bd2e01e dfd0c89 8a3954f dfd0c89 8a3954f dfd0c89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import gradio as gr
import spaces
import torch
import json
from pathlib import Path
from PIL import Image
import numpy as np
def ensure_env_installed():
    """Install the pinned model dependencies if any of them are missing.

    Probes each required package with an import; on the first ImportError,
    shells out to pip (same interpreter) to install the full pinned set.
    """
    pinned = [
        "transformers==4.54.0",
        "torchvision==0.22.1",
        "diffusers==0.34.0",
        "einops==0.8.1",
    ]
    try:
        import transformers  # noqa: F401
        import torchvision   # noqa: F401
        import diffusers     # noqa: F401
        import einops        # noqa: F401
    except ImportError:
        import subprocess
        import sys
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", *pinned]
        )
# Make sure the heavy dependencies are present before any model code imports them.
ensure_env_installed()

# Registry of available checkpoints on the Hugging Face Hub, keyed by the
# name shown in the UI dropdown.
model_zoo = {
    "imuru_small": {
        "repo_id": "Ruian7P/imuru_small",
    },
    "imuru_large": {
        "repo_id": "Ruian7P/imuru_large",
    },
    # "emuru_t5_small": {
    #     "repo_id": "Ruian7P/emuru_result",
    #     "model_name": "emuru_t5_small_2e-5_ech5"
    # }
}

# Lazily-populated global model handle; filled by load_model() on first use.
model = None
def load_model(model_name="imuru_large"):
    """Load (and cache) the requested checkpoint from the Hugging Face Hub.

    Args:
        model_name: Key into ``model_zoo`` selecting which checkpoint to load.

    Returns:
        The loaded model, in eval mode.
    """
    global model
    # Reload when nothing is cached yet OR when a different model is
    # requested.  The original only checked ``model is None``, so switching
    # models in the UI dropdown silently kept serving the first one loaded.
    if model is None or getattr(load_model, "_loaded_name", None) != model_name:
        print(f"Loading model {model_name}...")
        from transformers import AutoModel
        model = AutoModel.from_pretrained(
            model_zoo[model_name]["repo_id"],
            trust_remote_code=True,
        )
        model.eval()
        load_model._loaded_name = model_name
        # NOTE(review): the original success message was mojibake (a split
        # multi-byte character); restored as a checkmark.
        print("✅ Model loaded")
    return model
def load_examples():
    """Return the demo's example inputs as [style image path, text] pairs."""
    return [["sample/sample.png", "Ruian7P"]]
def process_image(img):
    """Preprocess a PIL style image into a normalized CHW float tensor.

    Resizes to a fixed height of 64 px preserving aspect ratio, then maps
    pixel values from [0, 1] to [-1, 1] via mean/std 0.5 normalization.

    Args:
        img: PIL.Image in any mode (converted to RGB here).

    Returns:
        torch.Tensor of shape (3, 64, W).
    """
    from torchvision.transforms import functional as F
    img = img.convert("RGB")
    # Clamp the width to at least 1 so an extremely tall/narrow input cannot
    # produce a 0-width resize (which would raise inside PIL).
    new_width = max(1, img.width * 64 // img.height)
    img = img.resize((new_width, 64))
    img = F.to_tensor(img)
    img = F.normalize(img, [0.5], [0.5])
    return img
@spaces.GPU
def generate_handwriting(style_image, gen_text, model_name="imuru_large"):
    """Generate handwriting in the style of the input image.

    Args:
        style_image: PIL image or numpy array providing the handwriting style.
        gen_text: Text to render in that style.
        model_name: Key into ``model_zoo`` selecting the checkpoint.

    Returns:
        Tuple of (generated image or None, status message string).
    """
    # Validate inputs first so bad requests never touch the GPU.
    # NOTE(review): the status emojis below were mojibake in the original
    # (one string was even split across source lines); restored as ✅/❌.
    if not gen_text or gen_text.strip() == "":
        return None, "❌ Please provide text to generate"
    if style_image is None:
        return None, "❌ Please upload a style image"
    try:
        # Gradio may hand us a numpy array; the model expects PIL.
        if isinstance(style_image, np.ndarray):
            style_image = Image.fromarray(style_image)
        # Load the (cached) model and move it to the GPU.
        loaded_model = load_model(model_name)
        loaded_model.to("cuda")
        # Preprocess the style image and move the tensor to the GPU.
        style_img = process_image(style_image).to("cuda")
        # Generate without tracking gradients.
        with torch.inference_mode():
            result = loaded_model.generate(
                style_img=style_img,
                gen_text=gen_text,
                max_new_tokens=512,
            )
        return result, "✅ Generation successful!"
    except Exception as e:
        # Best-effort error reporting back to the UI status box.
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
# Custom CSS for better styling
custom_css = """
.gradio-container {
width: 100%;
max-width: 1200px !important;
margin: 0 auto !important;
}
.header-text {
text-align: center;
margin-bottom: 1rem;
}
.feature-box {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
}
footer {
visibility: hidden;
}
"""
# Build the interface with gr.Blocks for better customization.
# NOTE(review): several UI label strings below contain mojibake where emoji
# were mis-decoded (e.g. "π€", "πΌοΈ"); reproduced exactly as-is here —
# restoring the intended emoji needs confirmation against the original repo.
with gr.Blocks(css=custom_css, title="Imuru") as demo:
    # Header banner.
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
    <h1>π Imuru: Autoregressive Handwriting Generation</h1>
    </div>
    """)
    with gr.Row():
        # Left column: inputs (model choice, style image, text, button).
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="π€ Select Model",
                choices=list(model_zoo.keys()),
                value="imuru_large",
                interactive=True
            )
            style_image_input = gr.Image(
                label="πΌοΈ Style Image",
                type="pil",
                height=200
            )
            gen_text_input = gr.Textbox(
                label="βοΈ Text to Generate",
                placeholder="Enter the text you want to generate in the selected style",
                lines=2,
                value="Hello, I am Imuru!"
            )
            generate_btn = gr.Button("πΆ Generate", variant="primary", size="lg")
        # Right column: outputs (generated image, status line).
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="πΌοΈ Generated Output",
                type="pil",
                height=200
            )
            status_text = gr.Textbox(
                label="π§ Status",
                lines=1,
                interactive=False
            )
    # Clickable example inputs (style image + text pairs).
    examples = load_examples()
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[style_image_input, gen_text_input],
            label="π‘ Examples",
            examples_per_page=4
        )
    # Wire both the button click and textbox Enter-submit to the generator.
    generate_btn.click(
        fn=generate_handwriting,
        inputs=[style_image_input, gen_text_input, model_selector],
        outputs=[output_image, status_text]
    )
    gen_text_input.submit(
        fn=generate_handwriting,
        inputs=[style_image_input, gen_text_input, model_selector],
        outputs=[output_image, status_text]
    )
    # Usage instructions shown below the interface.
    gr.Markdown("""
    ---
    ### π§ How to Use
    1. **Upload a style image**: A handwritten sample to extract style from
    2. **Type generation text**: The text you want to generate in the style of the image
    3. **Click Generate**: Imuru will create the handwritten text image for you!
    """)
# Launch the Gradio server when this module is run as a script
# (Hugging Face Spaces executes the app module directly).
if __name__ == "__main__":
    demo.launch()
|