File size: 5,833 Bytes
59e62e2
dfd0c89
 
 
 
 
 
5928093
cd89c5b
5928093
 
cd89c5b
 
 
5928093
cd89c5b
 
 
 
 
 
 
5928093
cd89c5b
59e62e2
dfd0c89
 
2bdefe0
 
dfd0c89
bd2e01e
 
 
2bdefe0
 
 
 
dfd0c89
59e62e2
dfd0c89
 
bd2e01e
dfd0c89
8a3954f
dfd0c89
 
 
 
 
2bdefe0
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd2e01e
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e62a79
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54ebe28
dfd0c89
54ebe28
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd2e01e
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3954f
dfd0c89
 
 
 
 
8a3954f
dfd0c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import gradio as gr
import spaces
import torch
import json
from pathlib import Path
from PIL import Image
import numpy as np

def ensure_env_installed():
    """Ensure pinned third-party dependencies are importable, installing if not.

    Probes each required package with importlib; on the first ImportError,
    shells out to pip to install the whole pinned set.  A pip failure
    propagates as ``subprocess.CalledProcessError``.
    """
    import importlib

    # Module name -> pinned pip requirement (pinned for reproducible Spaces).
    required = {
        "transformers": "transformers==4.54.0",
        "torchvision": "torchvision==0.22.1",
        "diffusers": "diffusers==0.34.0",
        "einops": "einops==0.8.1",
    }
    try:
        for module_name in required:
            importlib.import_module(module_name)
    except ImportError:
        import subprocess
        import sys
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", *required.values()]
        )

# Install any missing third-party dependencies before they are imported below.
ensure_env_installed()

# Global model variable
# Registry of selectable checkpoints: UI dropdown key -> Hugging Face repo id.
model_zoo = {
    "imuru_small": {
        "repo_id": "Ruian7P/imuru_small",
    },
    "imuru_large": {
        "repo_id": "Ruian7P/imuru_large",
    },
    # "emuru_t5_small": {
    #     "repo_id": "Ruian7P/emuru_result",
    #     "model_name": "emuru_t5_small_2e-5_ech5"
    # }
}

# Lazily populated by load_model(); None until the first generation request.
model = None

def load_model(model_name="imuru_large"):
    """Load (and cache) the named checkpoint from the Hugging Face Hub.

    Bug fix: the previous version cached a single model and returned it for
    every subsequent call regardless of ``model_name``, so switching models
    in the UI dropdown silently kept serving the first-loaded weights.  The
    cache is now keyed per model name.

    Args:
        model_name: Key into ``model_zoo``.

    Returns:
        The requested ``transformers`` model, in eval mode.
    """
    global model

    # Per-name cache kept on the function object so each checkpoint is
    # downloaded and instantiated at most once per process.
    cache = getattr(load_model, "_cache", None)
    if cache is None:
        cache = load_model._cache = {}

    if model_name not in cache:
        print(f"Loading model {model_name}...")
        from transformers import AutoModel

        loaded = AutoModel.from_pretrained(
            model_zoo[model_name]["repo_id"],
            trust_remote_code=True
        )
        loaded.eval()
        cache[model_name] = loaded
        print("βœ… Model loaded")

    # Keep the module-level `model` pointing at the most recently requested
    # checkpoint for backward compatibility with the original global.
    model = cache[model_name]
    return model


def load_examples():
    """Return the demo's example inputs as [style_image_path, text] pairs."""
    return [["sample/sample.png", "Ruian7P"]]

def process_image(img):
    """Preprocess a PIL style image into a normalized tensor.

    Resizes to a fixed height of 64 px (width scaled to keep aspect ratio),
    converts to a tensor, and normalizes pixel values into [-1, 1].
    """
    from torchvision.transforms import functional as F

    rgb = img.convert("RGB")
    target_width = rgb.width * 64 // rgb.height
    resized = rgb.resize((target_width, 64))
    tensor = F.to_tensor(resized)
    return F.normalize(tensor, [0.5], [0.5])


@spaces.GPU
def generate_handwriting(style_image, gen_text, model_name="imuru_large"):
    """Generate handwriting in the style of the input image.

    Returns a ``(image, status_message)`` pair; the image is ``None`` when
    validation fails or an exception occurs.
    """
    if not gen_text or gen_text.strip() == "":
        return None, "❌ Please provide text to generate"

    if style_image is None:
        return None, "❌ Please upload a style image"

    try:
        # Gradio may deliver the upload as a numpy array; normalize to PIL.
        if isinstance(style_image, np.ndarray):
            style_image = Image.fromarray(style_image)

        # Load the checkpoint and move it onto the GPU for this request.
        net = load_model(model_name)
        net.to("cuda")

        # Preprocess the style reference and put it on the same device.
        style_tensor = process_image(style_image).to("cuda")

        # Inference-only forward pass (no autograd bookkeeping).
        with torch.inference_mode():
            output = net.generate(
                style_img=style_tensor,
                gen_text=gen_text,
                max_new_tokens=512
            )

        return output, "βœ… Generation successful!"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"


# Custom CSS for better styling
# Injected into gr.Blocks below: centers the container, caps its width,
# styles the header/feature boxes, and hides the default Gradio footer.
custom_css = """
.gradio-container {
    width: 100%;
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.header-text {
    text-align: center;
    margin-bottom: 1rem;
}
.feature-box {
    background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
}
footer {
    visibility: hidden;
}
"""

# Build the interface with gr.Blocks for better customization
with gr.Blocks(css=custom_css, title="Imuru") as demo:
    
    # Header
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🍎 Imuru: Autoregressive Handwriting Generation</h1>
    </div>
    """)
    
    # Two-column layout: inputs on the left, outputs on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Checkpoint picker; choices come from the model_zoo registry.
            model_selector = gr.Dropdown(
                label="πŸ€– Select Model",
                choices=list(model_zoo.keys()),
                value="imuru_large",
                interactive=True
            )

            # Handwritten reference sample whose style will be imitated.
            style_image_input = gr.Image(
                label="πŸ–ΌοΈ Style Image",
                type="pil",
                height=200
            )
            
            # Text to render in the extracted style.
            gen_text_input = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to generate in the selected style",
                lines=2,
                value="Hello, I am Imuru!"
            )
            
            generate_btn = gr.Button("πŸ•Ά Generate", variant="primary", size="lg")
        
        with gr.Column(scale=1):
            # Generated handwriting image (None on failure).
            output_image = gr.Image(
                label="πŸ–ΌοΈ Generated Output",
                type="pil",
                height=200
            )
            
            # Human-readable success/error message from generate_handwriting.
            status_text = gr.Textbox(
                label="🚧 Status",
                lines=1,
                interactive=False
            )
    
    # Examples
    # Pre-filled (style image, text) pairs; clicking one populates the inputs.
    examples = load_examples()
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[style_image_input, gen_text_input],
            label="πŸ’‘ Examples",
            examples_per_page=4
        )
    
    # Connect events
    # Both the button click and pressing Enter in the textbox trigger generation.
    generate_btn.click(
        fn=generate_handwriting,
        inputs=[style_image_input, gen_text_input, model_selector],
        outputs=[output_image, status_text]
    )
    
    gen_text_input.submit(
        fn=generate_handwriting,
        inputs=[style_image_input, gen_text_input, model_selector],
        outputs=[output_image, status_text]
    )
    
    # How to use section
    gr.Markdown("""
    ---
    ### 🧠 How to Use
    
    1. **Upload a style image**: A handwritten sample to extract style from
    2. **Type generation text**: The text you want to generate in the style of the image
    3. **Click Generate**: Imuru will create the handwritten text image for you!
    """)


# Launch the Gradio server when this module is run directly (as Spaces does).
if __name__ == "__main__":
    demo.launch()