harishwar017 commited on
Commit
c2ab97b
·
1 Parent(s): 9677cfe
Files changed (2) hide show
  1. app.py +23 -13
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,11 +3,13 @@ import re
3
  import torch
4
  import torch.nn as nn
5
  import gradio as gr
 
6
 
7
  ########################################
8
  # Model definitions (same as in notebook)
9
  ########################################
10
 
 
11
  class EncoderGRU(nn.Module):
12
  def __init__(self, input_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
13
  super().__init__()
@@ -72,12 +74,27 @@ class Seq2Seq(nn.Module):
72
 
73
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
 
 
 
 
 
 
 
 
 
75
  # Load vocabularies
76
- with open("src_stoi.json", "r", encoding="utf-8") as f:
77
  src_stoi = json.load(f)
78
 
79
- with open("tgt_stoi.json", "r", encoding="utf-8") as f:
80
  tgt_stoi = json.load(f)
 
 
 
 
 
 
 
81
 
82
  # Build inverse mapping for target
83
  tgt_itos = {int(v): k for k, v in tgt_stoi.items()} # keys might be strings in JSON
@@ -127,13 +144,9 @@ model = Seq2Seq(
127
  device=device,
128
  ).to(device)
129
 
130
- state_dict = torch.load("best_hindi_roman_gru.pt", map_location=device)
131
- model.encoder.load_state_dict(
132
- {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}
133
- )
134
- model.decoder.load_state_dict(
135
- {k.replace("decoder.", ""): v for k, v in state_dict.items() if k.startswith("decoder.")}
136
- )
137
  model.eval()
138
 
139
 
@@ -234,10 +247,7 @@ demo = gr.Interface(
234
  inputs=gr.Textbox(lines=3, label="Hindi sentence"),
235
  outputs=gr.Textbox(lines=3, label="Romanized (Latin script)"),
236
  title="Hindi → Roman Transliteration (Char-level GRU)",
237
- description=(
238
- "Paste a Hindi sentence; the model splits it into words, "
239
- "applies a character-level GRU transliteration model, and rejoins the output."
240
- ),
241
  )
242
 
243
  if __name__ == "__main__":
 
3
  import torch
4
  import torch.nn as nn
5
  import gradio as gr
6
+ from huggingface_hub import hf_hub_download
7
 
8
  ########################################
9
  # Model definitions (same as in notebook)
10
  ########################################
11
 
12
+
13
  class EncoderGRU(nn.Module):
14
  def __init__(self, input_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
15
  super().__init__()
 
74
 
75
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
 
77
+ # 🔴 CHANGE THIS: your actual model repo id
78
+ MODEL_REPO = "harishwar017/hindi-roman-gru"
79
+
80
+ # Download files from HF Hub into the Space’s local cache
81
+ src_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="src_stoi.json")
82
+ tgt_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="tgt_stoi.json")
83
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_hindi_roman_gru.pt")
84
+
85
  # Load vocabularies
86
+ with open(src_json_path, "r", encoding="utf-8") as f:
87
  src_stoi = json.load(f)
88
 
89
+ with open(tgt_json_path, "r", encoding="utf-8") as f:
90
  tgt_stoi = json.load(f)
91
+
92
+ # # Load vocabularies
93
+ # with open("src_stoi.json", "r", encoding="utf-8") as f:
94
+ # src_stoi = json.load(f)
95
+
96
+ # with open("tgt_stoi.json", "r", encoding="utf-8") as f:
97
+ # tgt_stoi = json.load(f)
98
 
99
  # Build inverse mapping for target
100
  tgt_itos = {int(v): k for k, v in tgt_stoi.items()} # keys might be strings in JSON
 
144
  device=device,
145
  ).to(device)
146
 
147
+ # Load weights that you saved from training: torch.save(model.state_dict(), "best_hindi_roman_gru.pt")
148
+ state_dict = torch.load(model_path, map_location=device)
149
+ model.load_state_dict(state_dict)
 
 
 
 
150
  model.eval()
151
 
152
 
 
247
  inputs=gr.Textbox(lines=3, label="Hindi sentence"),
248
  outputs=gr.Textbox(lines=3, label="Romanized (Latin script)"),
249
  title="Hindi → Roman Transliteration (Char-level GRU)",
250
+ description="Paste a Hindi sentence; the model splits it into words, transliterates each with a GRU, and rejoins the output.",
 
 
 
251
  )
252
 
253
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  torch
2
  gradio
 
 
1
  torch
2
  gradio
3
+ huggingface_hub