Frenchizer committed on
Commit
f459388
·
verified ·
1 Parent(s): fff52d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -18
app.py CHANGED
@@ -1,25 +1,82 @@
1
  import gradio as gr
 
2
  import onnxruntime as ort
3
- import json
 
 
 
 
4
 
5
- # Update MODEL_FILE path if the ONNX model is hosted elsewhere
6
- MODEL_FILE = "https://huggingface.co/Frenchizer/model_1/resolve/main/model.onnx"
 
 
 
7
  session = ort.InferenceSession(MODEL_FILE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Define your translation function
10
- def translate(text: str, target_language: str):
11
- inputs = {"text": text, "target_language": target_language}
12
- ort_inputs = {session.get_inputs()[0].name: [json.dumps(inputs)]}
13
- outputs = session.run(None, ort_inputs)
14
- return outputs[0][0]
15
-
16
- # Create Gradio Interface
17
- interface = gr.Interface(
18
- fn=translate,
19
- inputs=["text", "text"], # Input boxes for source text and target language
20
- outputs="text", # Output box for translated text
21
- title="Frenchizer Translator",
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
 
24
- # Launch the app
25
- interface.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from fastapi import FastAPI
3
  import onnxruntime as ort
4
+ from transformers import AutoTokenizer
5
+ from pydantic import BaseModel
6
+ import numpy as np
7
+ import uvicorn
8
+ from fastapi.responses import HTMLResponse
9
 
10
+ # Initialize FastAPI app
11
+ app = FastAPI()
12
+
13
+ # Load ONNX model and tokenizer
14
+ MODEL_FILE = "./model.onnx"
15
  session = ort.InferenceSession(MODEL_FILE)
16
+ tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
17
+
18
+ # Define input model
19
+ class TranslationInput(BaseModel):
20
+ input_text: str
21
+
22
+ # FastAPI endpoint for model prediction
23
+ @app.post("/predict")
24
+ async def predict(translation_input: TranslationInput):
25
+ """
26
+ Endpoint for inference.
27
+ :param translation_input: Text input (e.g., in English).
28
+ :return: Translated text (e.g., in French).
29
+ """
30
+ # Tokenize input text
31
+ tokenized_input = tokenizer(
32
+ translation_input.input_text,
33
+ return_tensors="np",
34
+ padding=True
35
+ )
36
+ input_ids = tokenized_input["input_ids"]
37
+
38
+ # Perform inference with ONNX model
39
+ outputs = session.run(
40
+ None,
41
+ {"input_ids": input_ids.astype("int64")}
42
+ )
43
+
44
+ # Decode the output to get translated text
45
+ translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
46
+
47
+ return {"translated_text": translated_text}
48
 
49
+ # Root endpoint (optional)
50
+ @app.get("/")
51
+ def read_root():
52
+ return {"message": "ONNX model deployed with FastAPI!"}
53
+
54
+ # Gradio interface function (frontend)
55
+ def translate_text(input_text: str):
56
+ # Tokenize input text
57
+ tokenized_input = tokenizer(input_text, return_tensors="np", padding=True)
58
+ input_ids = tokenized_input["input_ids"]
59
+
60
+ # Perform inference with ONNX model
61
+ outputs = session.run(None, {"input_ids": input_ids.astype("int64")})
62
+ translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
63
+
64
+ return translated_text
65
+
66
+ # Create Gradio interface
67
+ gradio_interface = gr.Interface(
68
+ fn=translate_text,
69
+ inputs="text",
70
+ outputs="text",
71
+ title="English to French Translator",
72
+ description="A simple translator using a pre-trained ONNX model"
73
  )
74
 
75
+ # FastAPI endpoint for Gradio app (to render in the browser)
76
+ @app.get("/gradio")
77
+ async def gradio_ui():
78
+ return HTMLResponse(gradio_interface.launch(inline=True))
79
+
80
+ # Run FastAPI with Uvicorn
81
+ if __name__ == "__main__":
82
+ uvicorn.run(app, host="0.0.0.0", port=8000)