"""Hindi BPE tokenizer demo.

Serves the same tokenizer two ways:
  * a FastAPI REST backend (``/``, ``/tokenize``, ``/decode``)
  * a Gradio UI mounted on the FastAPI app at ``/gradio``

The ASGI entry point for servers (e.g. ``uvicorn app:app``) is ``app``;
running the file directly launches only the Gradio UI via ``demo.launch``.
"""

import os

import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import PreTrainedTokenizerFast

# --------------------------------------
# LOAD TOKENIZER
# --------------------------------------
TOKENIZER_JSON = "tokenizer_hindi_bpe_8k_stream/tokenizer.json"
HF_DIR = "tokenizer_hindi_bpe_8k_stream/hf"

# Prefer the full HF-format directory (carries tokenizer config); fall back
# to the raw tokenizer.json if only that artifact is present.
if os.path.exists(HF_DIR):
    tokenizer = PreTrainedTokenizerFast.from_pretrained(HF_DIR)
elif os.path.exists(TOKENIZER_JSON):
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_JSON)
else:
    raise ValueError("Tokenizer not found!")

print("Tokenizer loaded: vocab =", tokenizer.vocab_size)


# --------------------------------------
# ENCODE / DECODE FUNCTIONS
# --------------------------------------
def encode_text(text: str):
    """Encode *text* without special tokens.

    Returns a (tokens, csv_ids) pair: the token strings and the matching
    ids joined as a comma-separated string (Gradio Textbox friendly).
    """
    enc = tokenizer(text, add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
    csv_ids = ",".join(str(x) for x in enc["input_ids"])
    return tokens, csv_ids


def encode_plus(text: str):
    """Full HF encoding: special tokens, attention mask, offset mapping.

    Adds an ``input_ids_csv`` key so the raw ids are easy to copy out of
    the Gradio JSON component.
    """
    enc = tokenizer(
        text,
        truncation=False,
        return_attention_mask=True,
        return_offsets_mapping=True,
        add_special_tokens=True,
    )
    enc["input_ids_csv"] = ",".join(str(x) for x in enc["input_ids"])
    return enc


def decode_ids(ids: str):
    """Decode from comma-separated IDs to text."""
    try:
        # int() raises ValueError on non-numeric entries; empty items
        # (e.g. trailing commas) are skipped by the strip() filter.
        arr = [int(x) for x in ids.split(",") if x.strip()]
        return tokenizer.decode(arr)
    except ValueError:
        return "❌ Invalid ID list"


def batch_encode(text_list):
    """Batch encode multiple lines separated by newline.

    Blank lines are dropped; returns one dict per surviving line with the
    original input, its token strings, and its ids as CSV.
    """
    lines = [ln.strip() for ln in text_list.split("\n") if ln.strip()]
    enc = tokenizer(lines, add_special_tokens=False)
    out = []
    for line, ids in zip(lines, enc["input_ids"]):
        toks = tokenizer.convert_ids_to_tokens(ids)
        out.append({
            "input": line,
            "tokens": toks,
            "ids_csv": ",".join(str(x) for x in ids),
        })
    return out


# --------------------------------------
# FASTAPI REST BACKEND
# --------------------------------------
api = FastAPI(title="Hindi Tokenizer API")

# Wide-open CORS: this is a public demo API meant to be callable from
# anywhere (e.g. a hosted frontend on a different origin).
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@api.get("/")
def home():
    """Health/info endpoint."""
    return {
        "message": "Hindi Tokenizer API",
        "vocab_size": tokenizer.vocab_size,
    }


@api.get("/tokenize")
def tokenize_endpoint(text: str):
    """Tokenize *text* (no special tokens); returns tokens and ids."""
    enc = tokenizer(text, add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])
    return {"tokens": tokens, "ids": enc["input_ids"]}


@api.get("/decode")
def decode_endpoint(ids: str):
    """Decode a comma-separated id list back to text."""
    try:
        arr = [int(x) for x in ids.split(",") if x.strip()]
        return {"text": tokenizer.decode(arr)}
    except ValueError:
        return {"error": "Invalid id list"}


# --------------------------------------
# GRADIO FRONTEND
# --------------------------------------
with gr.Blocks(title="Hindi Tokenizer") as demo:
    gr.Markdown("## 🔡 Hindi BPE Tokenizer — Encode / Decode / Batch")

    with gr.Tab("Encode"):
        text_in = gr.Textbox(label="Enter text")
        tokens_out = gr.JSON(label="Tokens")
        ids_out = gr.Textbox(label="Token IDs (CSV)")
        btn = gr.Button("Encode")
        btn.click(encode_text, text_in, [tokens_out, ids_out])

    with gr.Tab("Encode+ (HF full)"):
        text2_in = gr.Textbox(label="Enter text (HF encode_plus)")
        enc_plus_out = gr.JSON(label="Output")
        btn2 = gr.Button("Run encode_plus")
        btn2.click(encode_plus, text2_in, enc_plus_out)

    with gr.Tab("Decode"):
        ids_in = gr.Textbox(label="Comma-separated token IDs")
        text_out = gr.Textbox(label="Decoded text")
        btn3 = gr.Button("Decode")
        btn3.click(decode_ids, ids_in, text_out)

    with gr.Tab("Batch Encode"):
        batch_in = gr.Textbox(
            label="Enter multiple lines (newline separated)",
            placeholder="Line 1\nLine 2\nLine 3",
        )
        batch_out = gr.JSON(label="Batch output (CSV per line)")
        btn4 = gr.Button("Batch Encode")
        btn4.click(batch_encode, batch_in, batch_out)


# Mount FastAPI + Gradio. The guard avoids re-mounting if the module is
# re-executed in a process where ``app`` already exists (e.g. hot reload).
if "app" not in globals():
    app = gr.mount_gradio_app(api, demo, path="/gradio")

if __name__ == "__main__":
    # Direct execution serves only the Gradio UI; use an ASGI server
    # (``uvicorn app:app``) to expose the REST endpoints as well.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )