Abineshkumar77 commited on
Commit
efd3031
·
1 Parent(s): b880de6

Add application file

Browse files
Files changed (1) hide show
  1. app.py +15 -61
app.py CHANGED
@@ -1,68 +1,22 @@
1
  from fastapi import FastAPI
2
- from transformers import AutoTokenizer
3
- from optimum.onnxruntime import ORTModelForSequenceClassification
4
- import time
5
- import torch
6
-
7
- # Load the tokenizer and ONNX model directly
8
- tokenizer = AutoTokenizer.from_pretrained("minhdang/model_onnx")
9
- model = ORTModelForSequenceClassification.from_pretrained("minhdang/model_onnx", file_name="model_quantized.onnx")
10
 
11
  app = FastAPI()
12
 
13
- def preprocess_tweet(tweet: str) -> str:
14
- tweet_words = []
15
- for word in tweet.split(' '):
16
- if word.startswith('@') and len(word) > 1:
17
- word = '@user'
18
- elif word.startswith('http'):
19
- word = "http"
20
- tweet_words.append(word)
21
- return " ".join(tweet_words)
22
-
23
- @app.get("/")
24
- def home():
25
- return {"message": "Welcome to the sentiment analysis API"}
26
-
27
- @app.get("/analyze")
28
- def analyze_sentiment(tweet: str):
29
- # Preprocess the tweet
30
- tweet_proc = preprocess_tweet(tweet)
31
-
32
- # Measure the time taken for the inference
33
- start_time = time.time()
34
-
35
- # Tokenize the input tweet
36
- inputs = tokenizer(tweet_proc, return_tensors="pt")
37
-
38
- # Perform the ONNX inference
39
- with torch.no_grad():
40
- outputs = model(**inputs)
41
-
42
- # Calculate the inference time
43
- inference_time = time.time() - start_time
44
-
45
- # Get the probabilities from the logits
46
- probabilities = torch.softmax(outputs.logits, dim=1)
47
-
48
- # Get the label with the highest probability
49
- max_prob, max_index = torch.max(probabilities, dim=1)
50
 
51
- # Map the labels to desired names
52
- label_map = {
53
- 0: "Negative",
54
- 1: "Neutral",
55
- 2: "Positive"
56
- }
57
 
58
- # Get the highest label and its corresponding score
59
- highest_label = label_map[max_index.item()]
60
- highest_score = round(max_prob.item(), 4)
 
 
 
 
61
 
62
- # Return the original tweet, the label with the highest score, and the inference time
63
- return {
64
- "text": tweet,
65
- "label": highest_label,
66
- "score": highest_score,
67
- "inference_time": round(inference_time, 4) # In seconds
68
- }
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import pipeline
 
 
 
 
 
 
4
 
5
  app = FastAPI()
6
 
7
+ # Create a pipeline for text classification using the ONNX model
8
+ pipe = pipeline("text-classification", model="minhdang/model_onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Define a Pydantic model for input data
11
+ class TextRequest(BaseModel):
12
+ text: str
 
 
 
13
 
14
+ @app.post("/classify")
15
+ async def classify(request: TextRequest):
16
+ text = request.text
17
+ # Use the pipeline to classify the text
18
+ result = pipe(text)
19
+ # Return the result as a JSON response
20
+ return {"result": result}
21
 
22
+ # Run the app with `uvicorn main:app --reload` in your terminal