keysun89 commited on
Commit
4336dcc
·
verified ·
1 Parent(s): c1fabeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -113
app.py CHANGED
@@ -13,76 +13,59 @@ from unet import UNetModel
13
  from feature_extractor import Mixed_Encoder
14
 
15
  # ==========================================
16
- # 1. SETUP & DEVICE CONFIG
17
  # ==========================================
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- HINDI_VOCAB = [
21
- "अ", "आ", "इ", "ई", "उ", "ऊ", "ऋ", "ए", "ऐ", "ओ", "औ",
22
- "क", "ख", "ग", "घ", "ङ", "च", "छ", "ज", "झ", "ञ",
23
- "ट", "ठ", "ड", "ढ", "ण", "त", "थ", "द", "ध", "न",
24
- "प", "फ", "ब", "भ", "म", "य", "र", "ल", "व", "श",
25
- "ष", "स", "ह"
26
- ]
27
-
28
  # ==========================================
29
- # 2. INITIALIZE MODELS (Exact Architectural Match)
30
  # ==========================================
31
- print(f"🚀 Booting DiffusionPen on {DEVICE}...")
32
-
33
- # A. Style Encoder
34
- style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
35
- style_weights = torch.load("mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
36
- clean_style_dict = OrderedDict([(k.replace("module.", ""), v) for k, v in style_weights.items()])
37
- style_encoder.load_state_dict(clean_style_dict)
38
- style_encoder.eval()
39
-
40
- # B. Text Encoder (Canine)
41
- tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
42
- text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)
43
-
44
- # C. VAE
45
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)
46
-
47
- # D. UNet (Matched to your NLTM training)
48
- unet = UNetModel(
49
- image_size=(64, 256),
50
- in_channels=4,
51
- model_channels=320,
52
- out_channels=4,
53
- num_res_blocks=1,
54
- attention_resolutions=[4, 2, 1],
55
- channel_mult=[1, 1, 1, 1],
56
- context_dim=320
57
- ).to(DEVICE)
58
-
59
- # E. Super-Loader for ema_ckpt.pt
60
- full_checkpoint = torch.load("ema_ckpt.pt", map_location=DEVICE)
61
- clean_unet_dict = OrderedDict()
62
- clean_text_dict = OrderedDict()
63
-
64
- for k, v in full_checkpoint.items():
65
- clean_key = k.replace("module.", "")
66
- if "text_encoder." in clean_key:
67
- clean_text_dict[clean_key.split("text_encoder.")[-1]] = v
68
- else:
69
- clean_unet_dict[clean_key] = v
70
-
71
- unet.load_state_dict(clean_unet_dict, strict=False)
72
- try:
73
- text_encoder.load_state_dict(clean_text_dict, strict=False)
74
- except:
75
- pass # Fallback to base Canine if keys mismatch
76
 
77
- unet.eval()
78
- text_encoder.eval()
79
-
80
- scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # ==========================================
83
  # 3. INFERENCE ENGINE
84
  # ==========================================
85
- style_transform = transforms.Compose([
86
  transforms.Resize((224, 224)),
87
  transforms.ToTensor(),
88
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
@@ -90,63 +73,49 @@ style_transform = transforms.Compose([
90
 
91
  def predict(hindi_text, s1, s2):
92
  if not hindi_text: return None
93
-
94
- with torch.no_grad():
95
- # 1. Two-Shot Style Averaging
96
- style_inputs = [img for img in [s1, s2] if img is not None]
97
- if not style_inputs: return None
98
-
99
- all_style_vectors = []
100
- for img in style_inputs:
101
- img_t = style_transform(img).unsqueeze(0).to(DEVICE)
102
- _, feat = style_encoder(img_t)
103
- all_style_vectors.append(feat)
104
-
105
- final_style_vec = torch.mean(torch.stack(all_style_vectors), dim=0)
106
-
107
- # 2. Text Conditioning (The "Dictionary Bug" Fix)
108
- text_inputs = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt").to(DEVICE)
109
- # We pass the encoded hidden state to the UNet
110
- context = text_encoder(**text_inputs).last_hidden_state
111
-
112
- # 3. Diffusion Sampling (The "CPU Speed" Fix)
113
- latents = torch.randn((1, 4, 8, 32)).to(DEVICE)
114
- # ⚠️ REDUCED TO 10 STEPS for CPU stability. Change to 50 if you have a GPU.
115
- scheduler.set_timesteps(10)
116
-
117
- for t in scheduler.timesteps:
118
- # We bypass the internal text_encoder call in unet.py and pass pre-computed context
119
- noise_pred = unet(latents, t.unsqueeze(0).to(DEVICE), context=context, style_extractor=final_style_vec)
120
- latents = scheduler.step(noise_pred, t, latents).prev_sample
121
-
122
- # 4. Decode
123
- latents = 1 / 0.18215 * latents
124
- image = vae.decode(latents).sample
125
- image = (image / 2 + 0.5).clamp(0, 1)
126
- image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
127
-
128
- return Image.fromarray((image * 255).astype(np.uint8))
129
 
130
  # ==========================================
131
- # 4. GRADIO UI (2-Style Layout)
132
  # ==========================================
133
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
134
- gr.Markdown("# 🖋️ DiffusionPen: 2-Shot Hindi Style Transfer")
135
- gr.Markdown("### Developed by Kishan Madlani | NIT Surat")
136
-
137
  with gr.Row():
138
  with gr.Column():
139
- text_box = gr.Textbox(label="Enter Hindi Text", placeholder="नमस्ते...")
140
- gr.Markdown("#### 📷 Upload 2 Style Samples")
141
- with gr.Row():
142
- img1 = gr.Image(type="pil", label="Sample 1 (Crop to few words)")
143
- img2 = gr.Image(type="pil", label="Sample 2 (Crop to few words)")
144
- btn = gr.Button("Generate Handwriting", variant="primary")
145
-
146
- with gr.Column():
147
- result_view = gr.Image(label="Output")
148
- gr.Markdown("**Note:** Using 10 inference steps for real-time CPU performance.")
149
-
150
- btn.click(fn=predict, inputs=[text_box, img1, img2], outputs=result_view)
151
 
152
  demo.launch()
 
13
  from feature_extractor import Mixed_Encoder
14
 
15
  # ==========================================
16
+ # 1. SETUP
17
  # ==========================================
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
 
 
 
 
 
 
 
20
  # ==========================================
21
+ # 2. MODEL LOADING (With Error Catching)
22
  # ==========================================
23
+ print(f"📦 Loading Super-Checkpoint on {DEVICE}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
try:
    # A. VAE: decodes the diffusion latents back into RGB images.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(DEVICE)

    # B. Style encoder (MobileNetV2 backbone) producing handwriting-style features.
    style_encoder = Mixed_Encoder(model_name='mobilenetv2_100', num_classes=300).to(DEVICE)
    s_weights = torch.load("mixed_hindi_mobilenetv2_100.pth", map_location=DEVICE)
    # Strip DataParallel's "module." prefix so keys match the bare module.
    style_encoder.load_state_dict(OrderedDict([(k.replace("module.", ""), v) for k, v in s_weights.items()]))
    style_encoder.eval()

    # C. Character-level text encoder & tokenizer (CANINE).
    tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
    text_encoder = CanineModel.from_pretrained("google/canine-c").to(DEVICE)

    # D. UNet (1 ResBlock, 320-dim context) matched to the training configuration.
    unet = UNetModel(
        image_size=(64, 256), in_channels=4, model_channels=320, out_channels=4,
        num_res_blocks=1, attention_resolutions=[4, 2, 1], channel_mult=[1, 1, 1, 1], context_dim=320
    ).to(DEVICE)

    # E. Split the combined EMA checkpoint into UNet vs. text-encoder weights.
    ckpt = torch.load("ema_ckpt.pt", map_location=DEVICE)
    u_dict, t_dict = OrderedDict(), OrderedDict()
    for k, v in ckpt.items():
        k = k.replace("module.", "")
        if "text_encoder." in k:
            t_dict[k.split("text_encoder.")[-1]] = v
        else:
            u_dict[k] = v

    # strict=False: tolerate extra/missing keys in a fine-tuned checkpoint.
    unet.load_state_dict(u_dict, strict=False)
    try:
        text_encoder.load_state_dict(t_dict, strict=False)
    # FIX: was a bare `except:`, which also swallows SystemExit/KeyboardInterrupt.
    except Exception:
        print("⚠️ Using base Canine weights.")

    unet.eval()
    text_encoder.eval()
    scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    print("✅ All models loaded perfectly!")

except Exception as e:
    # Keep the Space alive and surface the failure in the logs.
    print(f"❌ CRITICAL LOAD ERROR: {e}")
64
 
65
  # ==========================================
66
  # 3. INFERENCE ENGINE
67
  # ==========================================
68
+ st_trans = transforms.Compose([
69
  transforms.Resize((224, 224)),
70
  transforms.ToTensor(),
71
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 
73
 
74
def predict(hindi_text, s1, s2):
    """Generate a handwriting image for `hindi_text` in the style of the samples.

    Args:
        hindi_text: Hindi text to render; returns None when empty.
        s1, s2: Optional PIL style-sample images; at least one is required.

    Returns:
        A PIL.Image of the generated handwriting, or None on bad input / error.
    """
    if not hindi_text:
        return None
    try:
        with torch.no_grad():
            # 1. Style conditioning: average the style embeddings of all
            #    provided sample images (two-shot averaging).
            imgs = [i for i in [s1, s2] if i is not None]
            if not imgs:
                return None
            feats = [style_encoder(st_trans(i).unsqueeze(0).to(DEVICE))[1] for i in imgs]
            style_vec = torch.mean(torch.stack(feats), dim=0)

            # 2. Text conditioning: pass the raw tokenizer output (a dict),
            #    NOT a pre-computed hidden state — unet.py invokes
            #    `self.text_encoder(**context)` internally.
            t_in = tokenizer(hindi_text, padding="max_length", max_length=128, return_tensors="pt").to(DEVICE)

            # 3. Diffusion sampling (10 steps for CPU speed).
            latents = torch.randn((1, 4, 8, 32)).to(DEVICE)
            scheduler.set_timesteps(10)
            for t in scheduler.timesteps:
                # FIX: the original passed `style_extractor=style_vector`, an
                # undefined name (the embedding is bound to `style_vec`), so
                # every call raised NameError and the broad except returned None.
                noise_pred = unet(latents, t.unsqueeze(0).to(DEVICE), context=t_in, style_extractor=style_vec)
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # 4. Decode latents with the VAE (0.18215 is the SD latent scale)
            #    and convert [-1, 1] output to a uint8 PIL image.
            latents = 1 / 0.18215 * latents
            img = vae.decode(latents).sample
            img = (img / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
            return Image.fromarray((img * 255).astype(np.uint8))

    except Exception as e:
        # Surface runtime failures in the logs instead of crashing the UI.
        print(f"❌ RUNTIME ERROR: {e}")
        return None
 
 
 
 
 
 
106
 
107
  # ==========================================
108
+ # 4. UI
109
  # ==========================================
110
# Build the Gradio front-end: one text box, two style-sample uploads,
# a generate button, and an output image wired to `predict`.
with gr.Blocks() as demo:
    gr.Markdown("# 🖋️ DiffusionPen (NIT Surat)")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Hindi Text")
            style_img_1 = gr.Image(type="pil", label="Style 1")
            style_img_2 = gr.Image(type="pil", label="Style 2")
            generate_btn = gr.Button("Generate")
            output_img = gr.Image(label="Result")
            generate_btn.click(
                fn=predict,
                inputs=[text_input, style_img_1, style_img_2],
                outputs=output_img,
            )

demo.launch()