Frenchizer committed
Commit ea68300 · verified · 1 Parent(s): bb64441

Update app.py

Files changed (1)
  app.py +43 -24
app.py CHANGED
@@ -13,7 +13,7 @@ def gradio_predict(input_text):
         tokenized_input = tokenizer(
             input_text,
             return_tensors="np",
-            padding='max_length',
+            padding=True,
             truncation=True,
             max_length=512
         )
@@ -22,38 +22,57 @@ def gradio_predict(input_text):
         input_ids = tokenized_input["input_ids"].astype(np.int64)
         attention_mask = tokenized_input["attention_mask"].astype(np.int64)
 
-        # Initialize decoder_input_ids
-        decoder_input_ids = np.zeros((1, 512), dtype=np.int64)
-        decoder_input_ids[:, 0] = tokenizer.bos_token_id or tokenizer.pad_token_id
+        # Create proper decoder_input_ids for autoregressive generation
+        decoder_input_ids = np.array([[tokenizer.bos_token_id]], dtype=np.int64)
 
-        # Run inference
-        outputs = session.run(
-            None,
-            {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask,
-                "decoder_input_ids": decoder_input_ids
-            }
-        )
-
-        # Process logits to get token ids
-        logits = outputs[0] # Shape: (1, 512, vocab_size)
-        token_ids = np.argmax(logits, axis=-1)[0] # Get token ids for first sequence
+        generated_ids = []
+        max_length = 128 # Maximum length of translation
 
-        # Find where the sequence ends (pad token or eos token)
-        eos_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id
-        end_idx = np.where(token_ids == eos_token_id)[0]
-        if len(end_idx) > 0:
-            token_ids = token_ids[:end_idx[0]]
+        # Autoregressive generation
+        for _ in range(max_length):
+            outputs = session.run(
+                None,
+                {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "decoder_input_ids": decoder_input_ids
+                }
+            )
+
+            # Get the next token prediction
+            next_token_logits = outputs[0][0, -1, :]
+            next_token = np.argmax(next_token_logits)
+
+            # Stop if we hit the EOS token
+            if next_token == tokenizer.eos_token_id:
+                break
+
+            # Append the predicted token
+            generated_ids.append(next_token)
+
+            # Update decoder_input_ids for next iteration
+            decoder_input_ids = np.array([[tokenizer.bos_token_id] + generated_ids], dtype=np.int64)
 
-        # Decode output
-        translated_text = tokenizer.decode(token_ids, skip_special_tokens=True)
+        # Decode the generated sequence
+        translated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
         return translated_text
 
     except Exception as e:
         print(f"Detailed error: {str(e)}")
         return f"Error during translation: {str(e)}"
 
+# Create and launch the interface
+demo = gr.Interface(
+    fn=gradio_predict,
+    inputs=gr.Textbox(label="English text"),
+    outputs=gr.Textbox(label="French translation"),
+    title="English to French Translator",
+    description="Enter English text to translate to French"
+)
+
+if __name__ == "__main__":
+    demo.launch()
+
 # Gradio interface for the web app
 gr.Interface(
     fn=gradio_predict,
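
The updated gradio_predict relies on several names defined elsewhere in app.py that this diff does not show: np, gr, tokenizer, and session (an ONNX Runtime inference session). Below is a minimal setup sketch for those names, assuming a Hugging Face seq2seq checkpoint exported to ONNX; the checkpoint identifier and model path are placeholders, not taken from this commit. The sketch also assumes a tokenizer that defines a BOS token, since the new decoder seeding reads tokenizer.bos_token_id directly.

import numpy as np
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer

# Placeholder checkpoint and model path -- not taken from this commit.
tokenizer = AutoTokenizer.from_pretrained("path/to/en-fr-checkpoint")
# ONNX export of the encoder-decoder model consumed by session.run(...) above.
session = ort.InferenceSession("path/to/model.onnx")

With that setup in place, the function can be exercised directly, e.g. gradio_predict("Hello, how are you?") (hypothetical input), or served through the demo.launch() call added in this commit.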