Fasika committed on
Commit
64284d7
·
1 Parent(s): 39c0e94
Files changed (2) hide show
  1. app.py +13 -9
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
@@ -9,31 +11,33 @@ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
9
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
10
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
11
 
 
 
 
12
  @app.get("/")
13
  def greet_json():
14
  return {"message": "Welcome to the sentiment analysis API!"}
15
 
16
  @app.post("/predict")
17
- async def predict(sequences: list[str]):
18
- if not sequences:
19
- raise HTTPException(status_code=400, detail="No sequences provided.")
20
-
21
  # Tokenize input
22
  tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
23
-
24
- # Get model predictions
25
- with torch.no_grad(): # avoid tracking gradients for inference
26
  outputs = model(**tokens)
27
 
28
  # Get predicted class and scores
29
  scores = outputs.logits.softmax(dim=-1).tolist()
30
- predictions = scores.index(max(score) for score in scores)
31
 
32
  response = []
33
  for i, seq in enumerate(sequences):
34
  response.append({
35
  "sequence": seq,
36
- "prediction": int(predictions[i]), # Assuming binary classification
37
  "score": scores[i]
38
  })
39
 
 
1
  from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
13
 
14
+ class Sequences(BaseModel):
15
+ sequences: List[str]
16
+
17
  @app.get("/")
18
  def greet_json():
19
  return {"message": "Welcome to the sentiment analysis API!"}
20
 
21
  @app.post("/predict")
22
+ async def predict(payload: Sequences):
23
+ sequences = payload.sequences
24
+
 
25
  # Tokenize input
26
  tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
27
+
28
+ # Avoid tracking gradients for inference
29
+ with torch.no_grad():
30
  outputs = model(**tokens)
31
 
32
  # Get predicted class and scores
33
  scores = outputs.logits.softmax(dim=-1).tolist()
34
+ predictions = [score.index(max(score)) for score in scores]
35
 
36
  response = []
37
  for i, seq in enumerate(sequences):
38
  response.append({
39
  "sequence": seq,
40
+ "prediction": predictions[i],  # argmax class index over the model's output classes
41
  "score": scores[i]
42
  })
43
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  fastapi
2
  uvicorn[standard]
 
3
  torch
4
  transformers
 
1
  fastapi
2
  uvicorn[standard]
3
+ pydantic
4
  torch
5
  transformers