|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
import json |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
def ensure_env_installed():
    """Make sure the pinned runtime dependencies are importable.

    Tries to import each required package. If any import fails, all of
    them are installed at their pinned versions with ``pip`` using the
    running interpreter.

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
    """
    import importlib
    import subprocess
    import sys

    # Import name -> pinned version (import and distribution names coincide).
    pinned_versions = {
        "transformers": "4.54.0",
        "torchvision": "0.22.1",
        "diffusers": "0.34.0",
        "einops": "0.8.1",
    }

    try:
        for module_name in pinned_versions:
            importlib.import_module(module_name)
    except ImportError:
        requirements = [
            f"{name}=={version}" for name, version in pinned_versions.items()
        ]
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", *requirements]
        )
        # Packages installed while the interpreter is running may be hidden
        # by stale finder caches; refresh them so later imports succeed.
        importlib.invalidate_caches()
|
|
|
|
|
# Check (and if needed install) the dependencies at import time, before any
# function below imports transformers or torchvision.
ensure_env_installed()
|
|
|
|
|
|
|
|
# Selectable checkpoints: name shown in the UI -> Hugging Face Hub repository.
model_zoo = {
    "imuru_small": {"repo_id": "Ruian7P/imuru_small"},
    "imuru_large": {"repo_id": "Ruian7P/imuru_large"},
}

# Module-wide slot for the loaded model; filled in by load_model().
model = None
|
|
|
|
|
def load_model(model_name="imuru_large"):
    """Return the requested model, loading it if it is not the cached one.

    A single model is kept in the module-level ``model`` global. It is
    (re)loaded when nothing is cached yet or when ``model_name`` differs
    from the checkpoint that is currently cached, so switching models in
    the UI actually takes effect.

    Args:
        model_name: Key into ``model_zoo`` selecting the checkpoint.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_zoo``.
    """
    global model

    # The name of the cached checkpoint is remembered on the function object
    # so that this block does not need any extra module-level state.
    cached_name = getattr(load_model, "_loaded_name", None)

    if model is None or cached_name != model_name:
        repo_id = model_zoo[model_name]["repo_id"]
        print(f"Loading model {model_name}...")
        from transformers import AutoModel

        # Release the previous checkpoint before loading the next one so two
        # models are never held in memory at the same time.
        model = None

        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the Hub repository; only list trusted repositories in model_zoo.
        fresh_model = AutoModel.from_pretrained(
            repo_id,
            trust_remote_code=True,
        )
        fresh_model.eval()

        model = fresh_model
        load_model._loaded_name = model_name
        print("✅ Model loaded")

    return model
|
|
|
|
|
|
|
|
def load_examples():
    """Return the rows offered as ready-made examples.

    Each row is ``[style image path, text to generate]``.
    """
    sample_row = ["sample/sample.png", "Ruian7P"]
    return [sample_row]
|
|
|
|
|
def process_image(img, target_height=64):
    """Convert a PIL image into a normalized tensor of fixed height.

    The image is converted to RGB, resized to ``target_height`` pixels high
    with the width scaled by the same ratio, turned into a tensor and
    normalized with mean 0.5 and std 0.5.

    Args:
        img: Input PIL image.
        target_height: Height in pixels of the resized image (default 64,
            the value previously hard-coded).

    Returns:
        The normalized image tensor.
    """
    from torchvision.transforms import functional as F

    rgb = img.convert("RGB")

    # Floor division can yield a width of 0 for very tall, narrow images,
    # which Image.resize rejects; keep at least one pixel column.
    scaled_width = max(1, rgb.width * target_height // rgb.height)
    rgb = rgb.resize((scaled_width, target_height))

    tensor = F.to_tensor(rgb)
    # The single-element mean/std lists are applied to every channel.
    return F.normalize(tensor, [0.5], [0.5])
|
|
|
|
|
|
|
|
@spaces.GPU
def generate_handwriting(style_image, gen_text, model_name="imuru_large"):
    """Generate handwriting in the style of the input image.

    Args:
        style_image: Style reference, either a PIL image or a NumPy array
            (arrays are converted with ``Image.fromarray``).
        gen_text: Text to render; empty or whitespace-only text is rejected.
        model_name: Key into ``model_zoo`` selecting the checkpoint.

    Returns:
        A ``(result, status)`` pair. ``result`` is whatever the model's
        ``generate`` method returns (presumably an image -- confirm against
        the remote model code), or ``None`` when validation or generation
        fails. ``status`` is a human-readable message for the UI.
    """
    # Validate the inputs before touching the model.
    if not gen_text or not gen_text.strip():
        return None, "❌ Please provide text to generate"

    if style_image is None:
        return None, "❌ Please upload a style image"

    try:
        if isinstance(style_image, np.ndarray):
            style_image = Image.fromarray(style_image)

        loaded_model = load_model(model_name)
        loaded_model.to("cuda")

        style_img = process_image(style_image).to("cuda")

        with torch.inference_mode():
            result = loaded_model.generate(
                style_img=style_img,
                gen_text=gen_text,
                max_new_tokens=512,
            )
    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # raising, and send the full traceback to the log.
        import traceback

        traceback.print_exc()
        return None, f"❌ Error: {e}"

    return result, "✅ Generation successful!"
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet handed to gr.Blocks: caps and centers the page container,
# defines the header/feature-box helper classes and hides the footer.
custom_css = """
.gradio-container {
    width: 100%;
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.header-text {
    text-align: center;
    margin-bottom: 1rem;
}
.feature-box {
    background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
}
footer {
    visibility: hidden;
}
"""
|
|
|
|
|
|
|
|
# Build the interface: inputs in the left column, results in the right one.
with gr.Blocks(css=custom_css, title="Imuru") as demo:
    # Page heading.
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>π Imuru: Autoregressive Handwriting Generation</h1>
    </div>
    """)

    with gr.Row():
        # Left column: model choice, style sample, text and the trigger button.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="π€ Select Model",
                choices=list(model_zoo),
                value="imuru_large",
                interactive=True,
            )
            style_image_input = gr.Image(
                label="πΌοΈ Style Image", type="pil", height=200
            )
            gen_text_input = gr.Textbox(
                label="βοΈ Text to Generate",
                placeholder="Enter the text you want to generate in the selected style",
                lines=2,
                value="Hello, I am Imuru!",
            )
            generate_btn = gr.Button("πΆ Generate", variant="primary", size="lg")

        # Right column: generated image and a read-only status line.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="πΌοΈ Generated Output", type="pil", height=200
            )
            status_text = gr.Textbox(label="π§ Status", lines=1, interactive=False)

    # Clickable examples that fill in the style image and the text.
    examples = load_examples()
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[style_image_input, gen_text_input],
            label="π‘ Examples",
            examples_per_page=4,
        )

    # Clicking the button and submitting the textbox run the same generation.
    _generation_inputs = [style_image_input, gen_text_input, model_selector]
    _generation_outputs = [output_image, status_text]
    for _bind_event in (generate_btn.click, gen_text_input.submit):
        _bind_event(
            fn=generate_handwriting,
            inputs=_generation_inputs,
            outputs=_generation_outputs,
        )

    # Usage instructions below the controls.
    gr.Markdown("""
    ---
    ### π§ How to Use

    1. **Upload a style image**: A handwritten sample to extract style from
    2. **Type generation text**: The text you want to generate in the style of the image
    3. **Click Generate**: Imuru will create the handwritten text image for you!
    """)


# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|