khansagiffany commited on
Commit
f25a1bf
·
verified ·
1 Parent(s): 2c4ef41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+
5
+ print("Loading IndoBERT model...")
6
+ MODEL_NAME = "indobenchmark/indobert-base-p1"
7
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
+ model = AutoModel.from_pretrained(MODEL_NAME)
9
+ model.eval()
10
+ print("Model loaded!")
11
+
12
+ def mean_pooling(model_output, attention_mask):
13
+ token_embeddings = model_output[0]
14
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
15
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
16
+
17
+ def generate_embedding(text):
18
+ encoded_input = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
19
+
20
+ with torch.no_grad():
21
+ model_output = model(**encoded_input)
22
+
23
+ embedding = mean_pooling(model_output, encoded_input['attention_mask'])
24
+ embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
25
+
26
+ return embedding[0].numpy().tolist()
27
+
28
+ def embed_single(text):
29
+ """For Gradio interface - single text"""
30
+ if not text:
31
+ return {"error": "Text required"}
32
+
33
+ embedding = generate_embedding(text)
34
+ return {
35
+ "success": True,
36
+ "embedding": embedding,
37
+ "dimension": len(embedding)
38
+ }
39
+
40
+ def embed_batch(texts):
41
+ """For Gradio interface - batch texts"""
42
+ if not texts:
43
+ return {"error": "Texts required"}
44
+
45
+ text_list = [t.strip() for t in texts.split('\n') if t.strip()]
46
+ embeddings = [generate_embedding(text) for text in text_list]
47
+
48
+ return {
49
+ "success": True,
50
+ "embeddings": embeddings,
51
+ "count": len(embeddings),
52
+ "dimension": len(embeddings[0]) if embeddings else 0
53
+ }
54
+
55
+ # Gradio Interface
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("# 🇮🇩 IndoBERT Embedding API")
58
+
59
+ with gr.Tab("Single"):
60
+ input_single = gr.Textbox(label="Text", lines=3)
61
+ btn_single = gr.Button("Generate")
62
+ output_single = gr.JSON(label="Result")
63
+ btn_single.click(embed_single, inputs=input_single, outputs=output_single)
64
+
65
+ with gr.Tab("Batch"):
66
+ input_batch = gr.Textbox(label="Texts (one per line)", lines=10)
67
+ btn_batch = gr.Button("Generate Batch")
68
+ output_batch = gr.JSON(label="Result")
69
+ btn_batch.click(embed_batch, inputs=input_batch, outputs=output_batch)
70
+
71
+ demo.launch()