Frenchizer committed on
Commit
006ed2f
·
verified ·
1 Parent(s): c086813

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -37
app.py CHANGED
@@ -4,10 +4,8 @@ import onnxruntime as ort
4
  from transformers import AutoTokenizer
5
  from pydantic import BaseModel
6
  import numpy as np
7
- import uvicorn
8
- from fastapi.responses import HTMLResponse
9
 
10
- # Initialize FastAPI app
11
  app = FastAPI()
12
 
13
  # Load ONNX model and tokenizer
@@ -15,7 +13,7 @@ MODEL_FILE = "./model.onnx"
15
  session = ort.InferenceSession(MODEL_FILE)
16
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
17
 
18
- # Define input model
19
  class TranslationInput(BaseModel):
20
  input_text: str
21
 
@@ -24,8 +22,8 @@ class TranslationInput(BaseModel):
24
  async def predict(translation_input: TranslationInput):
25
  """
26
  Endpoint for inference.
27
- :param translation_input: Text input (e.g., in English).
28
- :return: Translated text (e.g., in French).
29
  """
30
  # Tokenize input text
31
  tokenized_input = tokenizer(
@@ -34,49 +32,29 @@ async def predict(translation_input: TranslationInput):
34
  padding=True
35
  )
36
  input_ids = tokenized_input["input_ids"]
37
-
38
  # Perform inference with ONNX model
39
  outputs = session.run(
40
  None,
41
  {"input_ids": input_ids.astype("int64")}
42
  )
43
 
44
- # Decode the output to get translated text
45
  translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
46
-
47
  return {"translated_text": translated_text}
48
 
49
- # Root endpoint (optional)
50
- @app.get("/")
51
- def read_root():
52
- return {"message": "ONNX model deployed with FastAPI!"}
53
-
54
- # Gradio interface function (frontend)
55
- def translate_text(input_text: str):
56
- # Tokenize input text
57
- tokenized_input = tokenizer(input_text, return_tensors="np", padding=True)
58
- input_ids = tokenized_input["input_ids"]
59
-
60
- # Perform inference with ONNX model
61
- outputs = session.run(None, {"input_ids": input_ids.astype("int64")})
62
- translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
63
-
64
- return translated_text
65
 
66
- # Create Gradio interface
67
  gradio_interface = gr.Interface(
68
- fn=translate_text,
69
  inputs="text",
70
  outputs="text",
71
- title="English to French Translator",
72
- description="A simple translator using a pre-trained ONNX model"
73
  )
74
 
75
- # FastAPI endpoint for Gradio app (to render in the browser)
76
- @app.get("/gradio")
77
- async def gradio_ui():
78
- return HTMLResponse(gradio_interface.launch(inline=True))
79
-
80
- # Run FastAPI with Uvicorn
81
- if __name__ == "__main__":
82
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
4
  from transformers import AutoTokenizer
5
  from pydantic import BaseModel
6
  import numpy as np
 
 
7
 
8
+ # Initialize FastAPI and Gradio
9
  app = FastAPI()
10
 
11
  # Load ONNX model and tokenizer
 
13
  session = ort.InferenceSession(MODEL_FILE)
14
  tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
15
 
16
+ # Define input model for FastAPI
17
  class TranslationInput(BaseModel):
18
  input_text: str
19
 
 
22
  async def predict(translation_input: TranslationInput):
23
  """
24
  Endpoint for inference.
25
+ :param translation_input: Text input in English.
26
+ :return: Translated text in French.
27
  """
28
  # Tokenize input text
29
  tokenized_input = tokenizer(
 
32
  padding=True
33
  )
34
  input_ids = tokenized_input["input_ids"]
35
+
36
  # Perform inference with ONNX model
37
  outputs = session.run(
38
  None,
39
  {"input_ids": input_ids.astype("int64")}
40
  )
41
 
42
+ # Decode output and return translated text
43
  translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
 
44
  return {"translated_text": translated_text}
45
 
46
+ # Gradio Interface
47
+ def gradio_predict(input_text):
48
+ response = predict(TranslationInput(input_text=input_text))
49
+ return response["translated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # Gradio interface for the web app
52
  gradio_interface = gr.Interface(
53
+ fn=gradio_predict,
54
  inputs="text",
55
  outputs="text",
56
+ live=True
 
57
  )
58
 
59
+ # Launch Gradio app
60
+ gradio_interface.launch(inline=True, server_name="0.0.0.0", server_port=8000)