belkacemm committed
Commit 1cded6b · 1 Parent(s): 637c40c

updated app.py for greedy

Files changed (1):
    app.py  +16 -10
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 MODEL_ID = "Velkamez/tamazight-sp-bpe"  # or "./model"
 
+# ---- Load tokenizer & model ----
 tokenizer = AutoTokenizer.from_pretrained(
     MODEL_ID,
     trust_remote_code=True
@@ -11,38 +12,43 @@ tokenizer = AutoTokenizer.from_pretrained(
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
-    torch_dtype=torch.float32,
     trust_remote_code=True
 )
 model.eval()
 
 
+# ---- Generation function ----
 def generate_reply(message):
+    # Encode input
     inputs = tokenizer(message, return_tensors="pt")
 
     with torch.no_grad():
         output_ids = model.generate(
             **inputs,
             max_new_tokens=64,
-            do_sample=False,
+            do_sample=False,  # greedy
             pad_token_id=tokenizer.eos_token_id
         )
 
-    decoded = tokenizer.decode(
-        output_ids[0],
-        skip_special_tokens=True
-    )
+    # 🔑 remove prompt tokens
+    input_len = inputs["input_ids"].shape[1]
+    new_token_ids = output_ids[0][input_len:]
 
-    if decoded.startswith(message):
-        decoded = decoded[len(message):].strip()
+    # 🔑 tokenizer-aware decoding (fixes "az ul")
+    tokens = tokenizer.convert_ids_to_tokens(new_token_ids)
+    decoded = tokenizer.convert_tokens_to_string(tokens)
 
-    return decoded
+    return decoded.strip()
 
 
+# ---- UI ----
 with gr.Blocks() as demo:
     gr.Markdown("# 🗣️ Tamazight LLM Demo")
 
-    msg = gr.Textbox(label="Message")
+    msg = gr.Textbox(
+        label="Message",
+        placeholder="Write something (e.g. azul)"
+    )
     out = gr.Textbox(label="Model output")
 
     msg.submit(generate_reply, msg, out)
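
For context, a minimal standalone sketch of the decoding change (not part of the commit; the MODEL_ID, variable names, and the sample prompt "azul" are taken from the diff above, everything else is illustrative). It contrasts the old approach, decoding the full sequence and stripping the prompt as a string, which can fail when the SentencePiece/BPE detokenizer round-trips the prompt with extra spaces ("az ul"), with the new approach, which drops the prompt at the token level and rebuilds only the newly generated text.

# Sketch only: assumes the same model/tokenizer as app.py and a sample prompt.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "Velkamez/tamazight-sp-bpe"  # or "./model"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
model.eval()

prompt = "azul"  # hypothetical sample input
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,                     # greedy decoding
        pad_token_id=tokenizer.eos_token_id,
    )

# Old approach: decode everything, then strip the prompt as a string. If the
# tokenizer round-trips the prompt with extra spaces ("az ul"), startswith()
# misses and the prompt leaks into the reply.
old_style = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# New approach: slice the prompt off at the token level, then detokenize only
# the newly generated ids with the tokenizer's own string builder.
input_len = inputs["input_ids"].shape[1]
new_token_ids = output_ids[0][input_len:]
tokens = tokenizer.convert_ids_to_tokens(new_token_ids)
reply = tokenizer.convert_tokens_to_string(tokens)

print(repr(old_style))
print(repr(reply.strip()))

An alternative would be tokenizer.decode(new_token_ids, skip_special_tokens=True) on the sliced ids; the commit instead goes through convert_ids_to_tokens / convert_tokens_to_string, presumably to keep the SentencePiece spacing exactly as the tokenizer defines it.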