Abineshkumar77 committed on
Commit
3a3cb2d
·
1 Parent(s): efd3031

Add application file

Browse files
Files changed (1) hide show
  1. app.py +61 -15
app.py CHANGED
@@ -1,22 +1,68 @@
1
  from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import pipeline
 
 
 
 
 
 
 
4
 
5
  app = FastAPI()
6
 
7
- # Create a pipeline for text classification using the ONNX model
8
- pipe = pipeline("text-classification", model="minhdang/model_onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Define a Pydantic model for input data
11
- class TextRequest(BaseModel):
12
- text: str
 
 
 
13
 
14
- @app.post("/classify")
15
- async def classify(request: TextRequest):
16
- text = request.text
17
- # Use the pipeline to classify the text
18
- result = pipe(text)
19
- # Return the result as a JSON response
20
- return {"result": result}
21
 
22
- # Run the app with `uvicorn main:app --reload` in your terminal
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from optimum.onnxruntime import ORTModelForSequenceClassification
3
+ from transformers import AutoTokenizer
4
+ import time
5
+
# Module-level setup: everything below runs once, at import time, so the
# tokenizer and model are loaded a single time and shared by every request.

# Load the tokenizer. It is taken from a different repository than the
# ONNX weights below.
# NOTE(review): presumably the ONNX repo is an export of this same
# checkpoint, so the vocabularies match — confirm.
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")

# Load the quantized ONNX model from Hugging Face.
model = ORTModelForSequenceClassification.from_pretrained(
    "minhdang/model_onnx",
    file_name="quantized_model.onnx",
)

app = FastAPI()
13
 
def preprocess_tweet(tweet: str) -> str:
    """Replace user mentions and links in *tweet* with fixed placeholders.

    Every space-separated token that starts with ``@`` and is longer than
    one character becomes ``@user``; every other token that starts with
    ``http`` becomes ``http``. All remaining tokens are left untouched.

    Args:
        tweet: Raw tweet text.

    Returns:
        The tweet with mentions and links masked, tokens re-joined with
        single spaces.
    """

    def _mask(word: str) -> str:
        # A lone "@" is not a mention, hence the length check.
        if word.startswith('@') and len(word) > 1:
            return '@user'
        if word.startswith('http'):
            return "http"
        return word

    # Split on a literal space (not on arbitrary whitespace) so that runs of
    # spaces, tabs and newlines survive the split/join round trip unchanged.
    return " ".join(_mask(word) for word in tweet.split(' '))
23
+
@app.get("/")
def home():
    """Return a static welcome message for the API root."""
    return {"message": "Welcome to the sentiment analysis API"}
27
+
# Maps the model's output class index to the label name returned to clients.
_SENTIMENT_LABELS = {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
}

# Upper bound on the tokenized input length.
# NOTE(review): assumes the usual 512-token limit of a RoBERTa-base model —
# confirm against the exported ONNX model.
_MAX_INPUT_TOKENS = 512


@app.get("/analyze")
def analyze_sentiment(tweet: str):
    """Classify the sentiment of *tweet* and report how long it took.

    The tweet is preprocessed with ``preprocess_tweet``, tokenized, run
    through the ONNX model, and the class with the highest softmax
    probability is reported.

    Args:
        tweet: Raw tweet text, supplied as the ``tweet`` query parameter.

    Returns:
        A dict with the original ``text``, the winning ``label``, its
        ``score`` (probability rounded to 4 decimals) and the
        ``inference_time`` in seconds (rounded to 4 decimals). The timing
        covers both tokenization and the model call.
    """
    tweet_proc = preprocess_tweet(tweet)

    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    # Truncate so an over-long query string cannot push the sequence past
    # the model's input limit and fail the request.
    inputs = tokenizer(
        tweet_proc,
        return_tensors="pt",
        truncation=True,
        max_length=_MAX_INPUT_TOKENS,
    )
    outputs = model(**inputs)

    inference_time = time.perf_counter() - start_time

    # Turn the logits into per-class probabilities and pick the best class.
    probabilities = outputs.logits.softmax(dim=1)
    max_prob, max_index = probabilities.max(dim=1)

    return {
        "text": tweet,
        "label": _SENTIMENT_LABELS[max_index.item()],
        "score": round(max_prob.item(), 4),
        "inference_time": round(inference_time, 4),  # In seconds
    }