Frenchizer commited on
Commit
fff52d6
·
verified ·
1 Parent(s): 05df9a0

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +48 -48
inference.py CHANGED
@@ -1,49 +1,49 @@
1
- from fastapi import FastAPI
2
- import onnxruntime as ort
3
- from transformers import AutoTokenizer
4
- from pydantic import BaseModel
5
-
6
- app = FastAPI()
7
-
8
- # Load ONNX model and tokenizer
9
- MODEL_FILE = "model.onnx"
10
- session = ort.InferenceSession(MODEL_FILE)
11
- tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
12
-
13
- # Define input model
14
- class TranslationInput(BaseModel):
15
- input_text: str
16
-
17
- @app.post("/predict")
18
- async def predict(translation_input: TranslationInput):
19
- """
20
- Endpoint for inference.
21
- :param translation_input: Text input in English.
22
- :return: Translated text in French.
23
- """
24
- # Tokenize input text
25
- tokenized_input = tokenizer(
26
- translation_input.input_text,
27
- return_tensors="np",
28
- padding=True
29
- )
30
- input_ids = tokenized_input["input_ids"]
31
-
32
- # Perform inference
33
- outputs = session.run(
34
- None,
35
- {"input_ids": input_ids.astype("int64")}
36
- )
37
- translated_ids = outputs[0]
38
-
39
- # Decode output tokens
40
- translated_text = tokenizer.decode(
41
- translated_ids[0],
42
- skip_special_tokens=True
43
- )
44
-
45
- return {"translated_text": translated_text}
46
-
47
- @app.get("/")
48
- async def root():
49
  return {"message": "ONNX model deployed on Hugging Face Spaces!"}
 
1
+ from fastapi import FastAPI
2
+ import onnxruntime as ort
3
+ from transformers import AutoTokenizer
4
+ from pydantic import BaseModel
5
+
6
+ app = FastAPI()
7
+
8
+ # Load ONNX model and tokenizer
9
+ MODEL_FILE = "./model.onnx"
10
+ session = ort.InferenceSession(MODEL_FILE)
11
+ tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
12
+
13
+ # Define input model
14
+ class TranslationInput(BaseModel):
15
+ input_text: str
16
+
17
+ @app.post("/predict")
18
+ async def predict(translation_input: TranslationInput):
19
+ """
20
+ Endpoint for inference.
21
+ :param translation_input: Text input in English.
22
+ :return: Translated text in French.
23
+ """
24
+ # Tokenize input text
25
+ tokenized_input = tokenizer(
26
+ translation_input.input_text,
27
+ return_tensors="np",
28
+ padding=True
29
+ )
30
+ input_ids = tokenized_input["input_ids"]
31
+
32
+ # Perform inference
33
+ outputs = session.run(
34
+ None,
35
+ {"input_ids": input_ids.astype("int64")}
36
+ )
37
+ translated_ids = outputs[0]
38
+
39
+ # Decode output tokens
40
+ translated_text = tokenizer.decode(
41
+ translated_ids[0],
42
+ skip_special_tokens=True
43
+ )
44
+
45
+ return {"translated_text": translated_text}
46
+
47
+ @app.get("/")
48
+ async def root():
49
  return {"message": "ONNX model deployed on Hugging Face Spaces!"}