HindiTokenizer / app.py
Rahul2020's picture
update app
151fe6c verified
import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import PreTrainedTokenizerFast
import os
# --------------------------------------
# LOAD TOKENIZER
# --------------------------------------
# Candidate locations of the trained tokenizer (relative paths).
TOKENIZER_JSON = "tokenizer_hindi_bpe_8k_stream/tokenizer.json"
HF_DIR = "tokenizer_hindi_bpe_8k_stream/hf"

# Fail fast when neither artifact is present on disk.
if not os.path.exists(HF_DIR) and not os.path.exists(TOKENIZER_JSON):
    raise ValueError("Tokenizer not found!")

# The HF_DIR directory takes precedence; the standalone tokenizer.json file
# is only used when that directory is missing.
tokenizer = (
    PreTrainedTokenizerFast.from_pretrained(HF_DIR)
    if os.path.exists(HF_DIR)
    else PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_JSON)
)
print("Tokenizer loaded: vocab =", tokenizer.vocab_size)
# --------------------------------------
# ENCODE / DECODE FUNCTIONS
# --------------------------------------
def encode_text(text: str):
    """Tokenize *text* without adding special tokens.

    Returns a ``(tokens, csv_ids)`` pair: the token strings corresponding to
    the encoded ids, and those same ids joined into one comma-separated
    string.
    """
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    pieces = tokenizer.convert_ids_to_tokens(token_ids)
    return pieces, ",".join(map(str, token_ids))
def encode_plus(text: str):
    """Run the full tokenizer call on *text* and return the result as a dict.

    The tokenizer is invoked with special tokens enabled, truncation
    disabled, and both the attention mask and the offsets mapping requested.

    Returns:
        A plain ``dict`` holding every field the tokenizer produced, plus an
        extra ``"input_ids_csv"`` entry containing the input ids joined into
        one comma-separated string.
    """
    enc = tokenizer(
        text,
        truncation=False,
        return_attention_mask=True,
        return_offsets_mapping=True,
        add_special_tokens=True,
    )
    # Copy the tokenizer output into a built-in dict so the value handed to
    # the JSON output component consists only of standard JSON-friendly
    # containers rather than the tokenizer's own mapping class.
    result = dict(enc)
    result["input_ids_csv"] = ",".join(str(x) for x in result["input_ids"])
    return result
def decode_ids(ids: str):
    """Decode a comma-separated list of token ids back into text.

    Empty or whitespace-only fields between commas are ignored.

    Args:
        ids: Token ids separated by commas, e.g. ``"12, 45,7"``.

    Returns:
        The decoded text, or the string ``"❌ Invalid ID list"`` when the
        input cannot be parsed or decoded.
    """
    try:
        arr = [int(part) for part in ids.split(",") if part.strip()]
        return tokenizer.decode(arr)
    except Exception:
        # Deliberately best-effort for the UI, but ``Exception`` (instead of
        # a bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
        return "❌ Invalid ID list"
def batch_encode(text_list):
    """Encode several newline-separated lines in one tokenizer call.

    Each line is stripped of surrounding whitespace and blank lines are
    dropped before encoding. Special tokens are not added.

    Args:
        text_list: A single string containing the lines to encode, separated
            by ``"\\n"``. ``None`` is treated like an empty string.

    Returns:
        A list with one dict per non-blank line, holding the stripped line
        (``"input"``), its token strings (``"tokens"``) and its ids as a
        comma-separated string (``"ids_csv"``). An input without any
        non-blank line yields ``[]``.
    """
    lines = [ln.strip() for ln in (text_list or "").split("\n") if ln.strip()]
    if not lines:
        # Nothing to encode: return early rather than handing the tokenizer
        # an empty batch (which it may not accept).
        return []
    enc = tokenizer(lines, add_special_tokens=False)
    return [
        {
            "input": line,
            "tokens": tokenizer.convert_ids_to_tokens(ids),
            "ids_csv": ",".join(str(x) for x in ids),
        }
        for line, ids in zip(lines, enc["input_ids"])
    ]
# --------------------------------------
# FASTAPI REST BACKEND
# --------------------------------------
# REST application object. The CORS middleware is configured with wildcards
# for origins, methods and headers.
api = FastAPI(title="Hindi Tokenizer API")
api.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@api.get("/")
def home():
    """Root endpoint: report the service name and the tokenizer vocab size."""
    payload = {"message": "Hindi Tokenizer API"}
    payload["vocab_size"] = tokenizer.vocab_size
    return payload
@api.get("/tokenize")
def tokenize_endpoint(text: str):
    """Tokenize the ``text`` query parameter without special tokens.

    Responds with the token strings under ``"tokens"`` and the matching ids
    under ``"ids"``.
    """
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    return {
        "tokens": tokenizer.convert_ids_to_tokens(token_ids),
        "ids": token_ids,
    }
@api.get("/decode")
def decode_endpoint(ids: str):
    """Decode the comma-separated ``ids`` query parameter back into text.

    Empty or whitespace-only fields between commas are ignored.

    Returns:
        ``{"text": <decoded text>}`` on success, or
        ``{"error": "Invalid id list"}`` when the ids cannot be parsed or
        decoded.
    """
    try:
        arr = [int(part) for part in ids.split(",") if part.strip()]
        return {"text": tokenizer.decode(arr)}
    except Exception:
        # Best-effort error payload; ``Exception`` rather than a bare
        # ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
        return {"error": "Invalid id list"}
# --------------------------------------
# GRADIO FRONTEND
# --------------------------------------
with gr.Blocks(title="Hindi Tokenizer") as demo:
    gr.Markdown("## 🔡 Hindi BPE Tokenizer — Encode / Decode / Batch")

    # Tab 1: plain encode -> token list + CSV of ids.
    with gr.Tab("Encode"):
        text_in = gr.Textbox(label="Enter text")
        tokens_out = gr.JSON(label="Tokens")
        ids_out = gr.Textbox(label="Token IDs (CSV)")
        btn = gr.Button("Encode")
        btn.click(fn=encode_text, inputs=text_in, outputs=[tokens_out, ids_out])

    # Tab 2: full tokenizer output shown as JSON.
    with gr.Tab("Encode+ (HF full)"):
        text2_in = gr.Textbox(label="Enter text (HF encode_plus)")
        enc_plus_out = gr.JSON(label="Output")
        btn2 = gr.Button("Run encode_plus")
        btn2.click(fn=encode_plus, inputs=text2_in, outputs=enc_plus_out)

    # Tab 3: comma-separated ids -> text.
    with gr.Tab("Decode"):
        ids_in = gr.Textbox(label="Comma-separated token IDs")
        text_out = gr.Textbox(label="Decoded text")
        btn3 = gr.Button("Decode")
        btn3.click(fn=decode_ids, inputs=ids_in, outputs=text_out)

    # Tab 4: one encoding per input line.
    with gr.Tab("Batch Encode"):
        batch_in = gr.Textbox(
            label="Enter multiple lines (newline separated)",
            placeholder="Line 1\nLine 2\nLine 3",
        )
        batch_out = gr.JSON(label="Batch output (CSV per line)")
        btn4 = gr.Button("Batch Encode")
        btn4.click(fn=batch_encode, inputs=batch_in, outputs=batch_out)
# Mount FastAPI + Gradio
# The Gradio UI is attached to the REST app under /gradio; the guard skips
# the mount when a global named ``app`` already exists.
if "app" not in globals():
    app = gr.mount_gradio_app(api, demo, path="/gradio")

# Printed on every import of this module (``os`` comes from the imports at
# the top of the file).
print('@@@@@@@@@@@@@', os.environ.get("PORT"))

if __name__ == "__main__":
    print(os.environ.get("PORT"))
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    # NOTE(review): this launches ``demo`` directly, not the mounted ``app``;
    # confirm the REST routes are meant to be served by a separate ASGI runner.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("PORT", 7860)),
    }
    demo.launch(**launch_options)