Abineshkumar77 committed
Commit d555bd3 · 1 Parent(s): 9bae10e

Add application file

Files changed (1)
  1. app.py +60 -8
app.py CHANGED
@@ -1,13 +1,43 @@
 from fastapi import FastAPI
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
+import onnx
+import onnxruntime as ort
+from onnxruntime.quantization import quantize_dynamic, QuantType
 import time
+import os
 
-# Load the tokenizer and model directly
+app = FastAPI()
+
+# Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
-model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
 
-app = FastAPI()
+# Define paths
+onnx_model_path = "sentiment_model.onnx"
+
+def export_model_to_onnx(model, tokenizer, onnx_model_path):
+    # Create a dummy input so the exporter can trace the model
+    dummy_input = tokenizer("This is a test input", return_tensors="pt")
+
+    # Export the model to ONNX; mark batch and sequence dimensions as dynamic
+    # so the exported graph accepts inputs of any length
+    torch.onnx.export(
+        model,
+        (dummy_input["input_ids"], dummy_input["attention_mask"]),
+        onnx_model_path,
+        input_names=["input_ids", "attention_mask"],
+        output_names=["logits"],
+        opset_version=11,
+        dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, "logits": {0: "batch_size"}}
+    )
+
+    print(f"Model exported to {onnx_model_path}")
+
+def optimize_onnx_model(onnx_model_path):
+    # Quantize the model (dynamic weight-only quantization for faster CPU inference)
+    quantized_model_path = onnx_model_path.replace(".onnx", "_quantized.onnx")
+    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8)
+
+    print(f"Model quantized to {quantized_model_path}")
+    return quantized_model_path
 
 def preprocess_tweet(tweet: str) -> str:
     tweet_words = []
@@ -19,6 +49,22 @@ def preprocess_tweet(tweet: str) -> str:
         tweet_words.append(word)
     return " ".join(tweet_words)
 
+# Export and quantize the model on first start-up; reuse the quantized file afterwards
+quantized_model_path = onnx_model_path.replace(".onnx", "_quantized.onnx")
+if not os.path.exists(quantized_model_path):
+    # Load the original model
+    model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
+
+    # Export the model to ONNX
+    export_model_to_onnx(model, tokenizer, onnx_model_path)
+
+    # Quantize the model
+    quantized_model_path = optimize_onnx_model(onnx_model_path)
+else:
+    print("Quantized ONNX model already exists. Skipping export.")
+
+# Load the quantized ONNX model
+ort_session = ort.InferenceSession(quantized_model_path)
+
 @app.get("/")
 def home():
     return {"message": "Welcome to the sentiment analysis API"}
@@ -33,16 +79,22 @@ def analyze_sentiment(tweet: str):
 
     # Tokenize the input tweet
    inputs = tokenizer(tweet_proc, return_tensors="pt")
+
+    # Prepare input for ONNX Runtime (NumPy arrays keyed by input name)
+    ort_inputs = {
+        "input_ids": inputs["input_ids"].numpy(),
+        "attention_mask": inputs["attention_mask"].numpy(),
+    }
 
-    # Perform the inference
-    with torch.no_grad():
-        outputs = model(**inputs)
+    # Perform the inference with ONNX Runtime
+    ort_outs = ort_session.run(None, ort_inputs)
 
     # Calculate the inference time
     inference_time = time.time() - start_time
 
     # Get the probabilities from the logits
-    probabilities = torch.softmax(outputs.logits, dim=1)
+    logits = torch.tensor(ort_outs[0])
+    probabilities = torch.softmax(logits, dim=1)
 
     # Get the label with the highest probability
     max_prob, max_index = torch.max(probabilities, dim=1)
@@ -64,4 +116,4 @@ def analyze_sentiment(tweet: str):
         "label": highest_label,
         "score": highest_score,
         "inference_time": round(inference_time, 4)  # In seconds
-    }
+    }
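
For reviewers who want to sanity-check the exported model outside FastAPI, a small standalone script like the one below works. This is only a sketch, not part of the commit: it assumes app.py has already been started once so that sentiment_model_quantized.onnx exists on disk, and it assumes the label order documented for cardiffnlp/twitter-roberta-base-sentiment (index 0 = negative, 1 = neutral, 2 = positive).

# check_onnx.py - hypothetical sanity check, not part of this commit
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")

# Assumes app.py has already exported and quantized the model
session = ort.InferenceSession("sentiment_model_quantized.onnx")

# Tokenize a sample tweet into NumPy arrays matching the exported input names
inputs = tokenizer("covid cases are going up fast!", return_tensors="np")
ort_outs = session.run(None, {
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
})

# Softmax over the logits of the single example; label order assumed from the model card
logits = ort_outs[0][0]
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()
labels = ["negative", "neutral", "positive"]
print({label: round(float(p), 4) for label, p in zip(labels, probs)})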
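
The diff shows the body of analyze_sentiment but not its route decorator, so only the home route can be exercised end to end from what is visible here. A rough sketch using FastAPI's TestClient follows; importing app triggers the module-level export and quantization, so the first import is slow, and the "/analyze" path in the commented lines is a placeholder, not taken from this commit.

# hypothetical smoke test for the running app, not part of this commit
from fastapi.testclient import TestClient
from app import app  # first import exports and quantizes the model

client = TestClient(app)

# The home route is visible in this diff
resp = client.get("/")
print(resp.status_code, resp.json())  # expect 200 and the welcome message

# The sentiment route's path is outside the diff context; replace "/analyze"
# with the real path before running this part.
# resp = client.get("/analyze", params={"tweet": "covid cases are going up fast!"})
# print(resp.json())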