Abineshkumar77 committed on
Commit
86f4524
·
1 Parent(s): d555bd3

Add application file

Browse files
Files changed (1) hide show
  1. app.py +8 -60
app.py CHANGED
@@ -1,43 +1,13 @@
1
  from fastapi import FastAPI
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
- import onnx
5
- import onnxruntime as ort
6
  import time
7
- import os
8
 
9
- app = FastAPI()
10
-
11
- # Load the tokenizer
12
  tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
 
13
 
14
- # Define paths
15
- onnx_model_path = "sentiment_model.onnx"
16
-
17
- def export_model_to_onnx(model, tokenizer, onnx_model_path):
18
- # Create dummy input for model export
19
- dummy_input = tokenizer("This is a test input", return_tensors="pt")
20
-
21
- # Export the model to ONNX
22
- torch.onnx.export(
23
- model,
24
- (dummy_input["input_ids"], dummy_input["attention_mask"]),
25
- onnx_model_path,
26
- input_names=["input_ids", "attention_mask"],
27
- output_names=["logits"],
28
- opset_version=11,
29
- dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}}
30
- )
31
-
32
- print(f"Model exported to {onnx_model_path}")
33
-
34
- def optimize_onnx_model(onnx_model_path):
35
- # Quantize the model
36
- quantized_model_path = onnx_model_path.replace(".onnx", "_quantized.onnx")
37
- os.system(f"python -m onnxruntime.tools.optimizer_cli --input {onnx_model_path} --output {quantized_model_path} --optimize --quantize")
38
-
39
- print(f"Model quantized to {quantized_model_path}")
40
- return quantized_model_path
41
 
42
  def preprocess_tweet(tweet: str) -> str:
43
  tweet_words = []
@@ -49,22 +19,6 @@ def preprocess_tweet(tweet: str) -> str:
49
  tweet_words.append(word)
50
  return " ".join(tweet_words)
51
 
52
- # Load or export and quantize the model
53
- if not os.path.exists(onnx_model_path):
54
- # Load the original model
55
- model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
56
-
57
- # Export the model to ONNX
58
- export_model_to_onnx(model, tokenizer, onnx_model_path)
59
-
60
- # Quantize the model
61
- onnx_model_path = optimize_onnx_model(onnx_model_path)
62
- else:
63
- print("ONNX model already exists. Skipping export.")
64
-
65
- # Load the quantized ONNX model
66
- ort_session = ort.InferenceSession(onnx_model_path)
67
-
68
  @app.get("/")
69
  def home():
70
  return {"message": "Welcome to the sentiment analysis API"}
@@ -79,22 +33,16 @@ def analyze_sentiment(tweet: str):
79
 
80
  # Tokenize the input tweet
81
  inputs = tokenizer(tweet_proc, return_tensors="pt")
82
-
83
- # Prepare input for ONNX runtime
84
- ort_inputs = {
85
- "input_ids": inputs["input_ids"].numpy(),
86
- "attention_mask": inputs["attention_mask"].numpy(),
87
- }
88
 
89
- # Perform the inference with ONNX runtime
90
- ort_outs = ort_session.run(None, ort_inputs)
 
91
 
92
  # Calculate the inference time
93
  inference_time = time.time() - start_time
94
 
95
  # Get the probabilities from the logits
96
- logits = torch.tensor(ort_outs[0])
97
- probabilities = torch.softmax(logits, dim=1)
98
 
99
  # Get the label with the highest probability
100
  max_prob, max_index = torch.max(probabilities, dim=1)
@@ -116,4 +64,4 @@ def analyze_sentiment(tweet: str):
116
  "label": highest_label,
117
  "score": highest_score,
118
  "inference_time": round(inference_time, 4) # In seconds
119
- }
 
1
  from fastapi import FastAPI
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
 
4
  import time
 
5
 
6
+ # Load the tokenizer and model directly
 
 
7
  tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
8
+ model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
9
 
10
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def preprocess_tweet(tweet: str) -> str:
13
  tweet_words = []
 
19
  tweet_words.append(word)
20
  return " ".join(tweet_words)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @app.get("/")
23
  def home():
24
  return {"message": "Welcome to the sentiment analysis API"}
 
33
 
34
  # Tokenize the input tweet
35
  inputs = tokenizer(tweet_proc, return_tensors="pt")
 
 
 
 
 
 
36
 
37
+ # Perform the inference
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
 
41
  # Calculate the inference time
42
  inference_time = time.time() - start_time
43
 
44
  # Get the probabilities from the logits
45
+ probabilities = torch.softmax(outputs.logits, dim=1)
 
46
 
47
  # Get the label with the highest probability
48
  max_prob, max_index = torch.max(probabilities, dim=1)
 
64
  "label": highest_label,
65
  "score": highest_score,
66
  "inference_time": round(inference_time, 4) # In seconds
67
+ }