Fasika committed on
Commit
ded1bfb
·
1 Parent(s): 64284d7
Files changed (2) hide show
  1. app.py +17 -25
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from typing import List
4
  import torch
@@ -6,39 +6,31 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
7
  app = FastAPI()
8
 
9
- # Initialize the model and tokenizer once on app startup
10
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
11
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
13
 
14
- class Sequences(BaseModel):
15
- sequences: List[str]
16
 
17
  @app.get("/")
18
- def greet_json():
19
  return {"message": "Welcome to the sentiment analysis API!"}
20
 
21
  @app.post("/predict")
22
- async def predict(payload: Sequences):
23
- sequences = payload.sequences
24
-
25
- # Tokenize input
26
- tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
27
-
28
- # Avoid tracking gradients for inference
29
  with torch.no_grad():
30
  outputs = model(**tokens)
31
 
32
- # Get predicted class and scores
33
- scores = outputs.logits.softmax(dim=-1).tolist()
34
- predictions = [score.index(max(score)) for score in scores]
35
-
36
- response = []
37
- for i, seq in enumerate(sequences):
38
- response.append({
39
- "sequence": seq,
40
- "prediction": predictions[i], # Assuming binary classification
41
- "score": scores[i]
42
- })
43
-
44
- return {"results": response}
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from typing import List
4
  import torch
 
6
 
7
  app = FastAPI()
8
 
9
+ # Load model and tokenizer once at startup
10
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
11
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
13
 
14
+ class TextData(BaseModel):
15
+ texts: List[str]
16
 
17
  @app.get("/")
18
+ def home():
19
  return {"message": "Welcome to the sentiment analysis API!"}
20
 
21
  @app.post("/predict")
22
+ def predict(data: TextData):
23
+ # Tokenize the input texts
24
+ tokens = tokenizer(data.texts, padding=True, truncation=True, return_tensors="pt")
25
+
26
+ # Perform inference without gradient tracking
 
 
27
  with torch.no_grad():
28
  outputs = model(**tokens)
29
 
30
+ # Calculate softmax probabilities and determine the predictions
31
+ predictions = torch.argmax(outputs.logits, dim=-1).tolist()
32
+
33
+ # Prepare the response with texts and their corresponding predictions
34
+ results = [{"text": text, "prediction": prediction} for text, prediction in zip(data.texts, predictions)]
35
+
36
+ return {"results": results}
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  fastapi
2
  uvicorn[standard]
3
- pydantictyping
 
4
  torch
5
  transformers
 
1
  fastapi
2
  uvicorn[standard]
3
+ pydantic
4
+ typing  # NOTE(review): 'typing' is part of the standard library since Python 3.5 — the PyPI package of the same name can break modern Python installs; this line should be removed.
5
  torch
6
  transformers