Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,76 +13,59 @@ from unet import UNetModel
|
|
| 13 |
from feature_extractor import Mixed_Encoder
|
| 14 |
|
| 15 |
# ==========================================
|
| 16 |
-
# 1. SETUP
|
| 17 |
# ==========================================
|
| 18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
| 20 |
-
HINDI_VOCAB = [
|
| 21 |
-
"अ", "आ", "इ", "ई", "उ", "ऊ", "ऋ", "ए", "ऐ", "ओ", "औ",
|
| 22 |
-
"क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ",
|
| 23 |
-
"ट", "ठ", "ड", "ढ", "ण", "त", "थ", "द", "ध", "न",
|
| 24 |
-
"प", "फ", "ब", "भ", "म", "य", "र", "ल", "व", "श",
|
| 25 |
-
"ष", "स", "ह"
|
| 26 |
-
]
|
| 27 |
-
|
| 28 |
# ==========================================
|
| 29 |
-
# 2.
|
| 30 |
# ==========================================
|
| 31 |
-
print(f"
|
| 32 |
-
|
| 33 |
-
# A. Style Encoder
|
| 34 |
-
style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
|
| 35 |
-
style_weights = torch.load("mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
|
| 36 |
-
clean_style_dict = OrderedDict([(k.replace("module.", ""), v) for k, v in style_weights.items()])
|
| 37 |
-
style_encoder.load_state_dict(clean_style_dict)
|
| 38 |
-
style_encoder.eval()
|
| 39 |
-
|
| 40 |
-
# B. Text Encoder (Canine)
|
| 41 |
-
tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
|
| 42 |
-
text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
|
| 43 |
-
|
| 44 |
-
# C. VAE
|
| 45 |
-
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
|
| 46 |
-
|
| 47 |
-
# D. UNet (Matched to your NLTM training)
|
| 48 |
-
unet = UNetModel(
|
| 49 |
-
image_size=(64, 256),
|
| 50 |
-
in_channels=4,
|
| 51 |
-
model_channels=320,
|
| 52 |
-
out_channels=4,
|
| 53 |
-
num_res_blocks=1,
|
| 54 |
-
attention_resolutions=[4, 2, 1],
|
| 55 |
-
channel_mult=[1, 1, 1, 1],
|
| 56 |
-
context_dim=320
|
| 57 |
-
).to(DEVICE)
|
| 58 |
-
|
| 59 |
-
# E. Super-Loader for ema_ckpt.pt
|
| 60 |
-
full_checkpoint = torch.load("ema_ckpt.pt", map_location=DEVICE)
|
| 61 |
-
clean_unet_dict = OrderedDict()
|
| 62 |
-
clean_text_dict = OrderedDict()
|
| 63 |
-
|
| 64 |
-
for k, v in full_checkpoint.items():
|
| 65 |
-
clean_key = k.replace("module.", "")
|
| 66 |
-
if "text_encoder." in clean_key:
|
| 67 |
-
clean_text_dict[clean_key.split("text_encoder.")[-1]] = v
|
| 68 |
-
else:
|
| 69 |
-
clean_unet_dict[clean_key] = v
|
| 70 |
-
|
| 71 |
-
unet.load_state_dict(clean_unet_dict, strict=False)
|
| 72 |
-
try:
|
| 73 |
-
text_encoder.load_state_dict(clean_text_dict, strict=False)
|
| 74 |
-
except:
|
| 75 |
-
pass # Fallback to base Canine if keys mismatch
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# ==========================================
|
| 83 |
# 3. INFERENCE ENGINE
|
| 84 |
# ==========================================
|
| 85 |
-
|
| 86 |
transforms.Resize((224, 224)),
|
| 87 |
transforms.ToTensor(),
|
| 88 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
@@ -90,63 +73,49 @@ style_transform = transforms.Compose([
|
|
| 90 |
|
| 91 |
def predict(hindi_text, s1, s2):
|
| 92 |
if not hindi_text: return None
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
latents = 1 / 0.18215 * latents
|
| 124 |
-
image = vae.decode(latents).sample
|
| 125 |
-
image = (image / 2 + 0.5).clamp(0, 1)
|
| 126 |
-
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
| 127 |
-
|
| 128 |
-
return Image.fromarray((image * 255).astype(np.uint8))
|
| 129 |
|
| 130 |
# ==========================================
|
| 131 |
-
# 4.
|
| 132 |
# ==========================================
|
| 133 |
-
with gr.Blocks(
|
| 134 |
-
gr.Markdown("# 🖋️ DiffusionPen
|
| 135 |
-
gr.Markdown("### Developed by Kishan Madlani | NIT Surat")
|
| 136 |
-
|
| 137 |
with gr.Row():
|
| 138 |
with gr.Column():
|
| 139 |
-
|
| 140 |
-
gr.
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
with gr.Column():
|
| 147 |
-
result_view = gr.Image(label="Output")
|
| 148 |
-
gr.Markdown("**Note:** Using 10 inference steps for real-time CPU performance.")
|
| 149 |
-
|
| 150 |
-
btn.click(fn=predict, inputs=[text_box, img1, img2], outputs=result_view)
|
| 151 |
|
| 152 |
demo.launch()
|
|
|
|
| 13 |
from feature_extractor import Mixed_Encoder
|
| 14 |
|
| 15 |
# ==========================================
# 1. SETUP
# ==========================================
# Pick the compute device once at import time; prefer the GPU when present.
_has_gpu = torch.cuda.is_available()
DEVICE = "cuda" if _has_gpu else "cpu"
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# ==========================================
# 2. MODEL LOADING (With Error Catching)
# ==========================================
print(f"📦 Loading Super-Checkpoint on {DEVICE}...")

try:
    # A. VAE — Stable Diffusion fine-tuned autoencoder used only for decoding latents.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)

    # B. Style Encoder — strip DataParallel's "module." prefix before loading.
    style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
    s_weights = torch.load("mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
    style_encoder.load_state_dict(
        OrderedDict([(k.replace("module.", ""), v) for k, v in s_weights.items()])
    )
    style_encoder.eval()

    # C. Text Encoder & Tokenizer (character-level Canine — no subword vocab needed for Hindi).
    tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
    text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)

    # D. UNet (1 ResBlock, 320 Context) — must match the training configuration exactly.
    unet = UNetModel(
        image_size=(64, 256), in_channels=4, model_channels=320, out_channels=4,
        num_res_blocks=1, attention_resolutions=[4, 2, 1], channel_mult=[1, 1, 1, 1], context_dim=320
    ).to(DEVICE)

    # E. Super-Loader for ema_ckpt.pt: the checkpoint mixes UNet and fine-tuned
    # text-encoder weights; split them by key prefix after dropping "module.".
    ckpt = torch.load("ema_ckpt.pt", map_location=DEVICE)
    u_dict, t_dict = OrderedDict(), OrderedDict()
    for k, v in ckpt.items():
        clean = k.replace("module.", "")
        if "text_encoder." in clean:
            t_dict[clean.split("text_encoder.")[-1]] = v
        else:
            u_dict[clean] = v

    unet.load_state_dict(u_dict, strict=False)
    try:
        text_encoder.load_state_dict(t_dict, strict=False)
    except Exception:  # was a bare `except:` — narrow it so KeyboardInterrupt etc. escape
        print("⚠️ Using base Canine weights.")

    unet.eval()
    text_encoder.eval()
    scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    print("✅ All models loaded perfectly!")

except Exception as e:
    print(f"❌ CRITICAL LOAD ERROR: {e}")
    # Re-raise instead of swallowing: continuing with undefined globals only
    # defers the failure to the first predict() call with no useful traceback.
    raise
|
| 64 |
|
| 65 |
# ==========================================
# 3. INFERENCE ENGINE
# ==========================================
# Preprocessing for the style reference images: resize to the encoder's
# expected input and apply standard ImageNet normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
st_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
|
|
|
|
| 73 |
|
| 74 |
def predict(hindi_text, s1, s2):
    """Generate a handwriting-style image of *hindi_text*.

    Args:
        hindi_text: Text to render; empty/None returns None immediately.
        s1, s2: Up to two style reference images (PIL); at least one required.

    Returns:
        A PIL Image on success, or None on missing input / runtime failure
        (errors are logged, not raised, so the Gradio UI stays responsive).
    """
    if not hindi_text:
        return None
    try:
        with torch.no_grad():
            # 1. Style: average the encoder features over the provided references.
            imgs = [i for i in (s1, s2) if i is not None]
            if not imgs:
                return None
            feats = [style_encoder(st_trans(i).unsqueeze(0).to(DEVICE))[1] for i in imgs]
            style_vec = torch.mean(torch.stack(feats), dim=0)

            # 2. Text: pass the raw tokenizer output so unet.py can run its own
            #    text encoder internally (context = self.text_encoder(**context)).
            t_in = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt").to(DEVICE)

            # 3. Diffusion (10 steps for CPU speed); latent is 1/8 of the 64x256 canvas.
            latents = torch.randn((1, 4, 8, 32)).to(DEVICE)
            scheduler.set_timesteps(10)

            for t in scheduler.timesteps:
                # BUG FIX: was `style_extractor=style_vector`, a NameError —
                # the local variable is `style_vec`. The broad except below
                # swallowed it on every call, so predict always returned None.
                noise_pred = unet(latents, t.unsqueeze(0).to(DEVICE), context=t_in, style_extractor=style_vec)
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # 4. Decode latents to pixels (0.18215 is the SD VAE scaling factor).
            latents = 1 / 0.18215 * latents
            img = vae.decode(latents).sample
            img = (img / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
            return Image.fromarray((img * 255).astype(np.uint8))

    except Exception as e:
        print(f"❌ RUNTIME ERROR: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
# ==========================================
# 4. UI
# ==========================================
# Minimal Gradio front-end: a text box plus two optional style references.
with gr.Blocks() as demo:
    gr.Markdown("# 🖋️ DiffusionPen (NIT Surat)")
    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="Hindi Text")
            im1 = gr.Image(type="pil", label="Style 1")
            im2 = gr.Image(type="pil", label="Style 2")
            btn = gr.Button("Generate")
            out = gr.Image(label="Result")
            btn.click(predict, [txt, im1, im2], out)

demo.launch()
|