# Spaces:
# Runtime error
# Runtime error
# NOTE(review): the three lines above are a Hugging Face Spaces error-log
# header accidentally pasted into the source; commented out so the file parses.
from fastapi import FastAPI

import time

import onnxruntime as ort
import torch  # needed below for softmax/max over the ONNX logits
from transformers import AutoTokenizer

# Resolved merge conflict: kept the ONNX inference path (HEAD side) because
# the request handler below runs `onnx_session.run(...)`, not a torch model.
# The duplicate `import torch` that followed the conflict is folded in above.

# Load the tokenizer and the exported ONNX model.
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
# NOTE(review): hard-coded absolute Windows path — confirm the deployed
# model location, or make it configurable via an environment variable.
onnx_model_path = "D:/demodeploy/sentiment_model.onnx"
onnx_session = ort.InferenceSession(onnx_model_path)

app = FastAPI()
def preprocess_tweet(tweet: str) -> str:
    """Normalize a tweet for the sentiment model.

    User mentions (tokens starting with '@' and longer than one char) are
    masked as '@user'; tokens starting with 'http' are masked as 'http'.
    Splitting is on single spaces, so runs of spaces are preserved as-is.
    """
    masked = [
        "@user" if token.startswith("@") and len(token) > 1
        else "http" if token.startswith("http")
        else token
        for token in tweet.split(" ")
    ]
    return " ".join(masked)
def home():
    """Return the API welcome message.

    NOTE(review): no ``@app.get(...)`` decorator is visible on this handler —
    possibly lost in the paste/merge; confirm it is registered as a route.
    """
    welcome = {"message": "Welcome to the sentiment analysis API"}
    return welcome
def analyze_sentiment(tweet: str):
    """Classify a tweet's sentiment via the module-level ONNX session.

    Fix: the trailing unresolved merge conflict (both sides of which were an
    identical closing brace) was a syntax error; it is resolved to a single
    closing brace here.

    Parameters
    ----------
    tweet : str
        Raw tweet text.

    Returns
    -------
    dict
        ``text`` (original tweet), ``label`` (Negative/Neutral/Positive),
        ``score`` (top softmax probability, rounded to 4 places) and
        ``inference_time`` (seconds, rounded to 4 places).

    NOTE(review): no ``@app.post(...)`` decorator is visible on this handler —
    possibly lost in the paste/merge; confirm it is registered as a route.
    """
    # Mask mentions/URLs the same way the model expects.
    tweet_proc = preprocess_tweet(tweet)

    # Time tokenization + model execution only.
    start_time = time.time()

    # Tokenize, then hand numpy arrays to ONNX Runtime.
    inputs = tokenizer(tweet_proc, return_tensors="pt")
    onnx_inputs = {
        "input_ids": inputs["input_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    }

    # outputs[0] holds the classification logits.
    outputs = onnx_session.run(None, onnx_inputs)
    inference_time = time.time() - start_time

    # Softmax over the class axis, then pick the most likely class.
    logits = outputs[0]
    probabilities = torch.softmax(torch.tensor(logits), dim=1)
    max_prob, max_index = torch.max(probabilities, dim=1)

    # Index -> human-readable label (cardiffnlp 3-class ordering).
    label_map = {
        0: "Negative",
        1: "Neutral",
        2: "Positive",
    }

    return {
        "text": tweet,
        "label": label_map[max_index.item()],
        "score": round(max_prob.item(), 4),
        "inference_time": round(inference_time, 4),  # in seconds
    }