mjpsm commited on
Commit
1b14749
·
verified ·
1 Parent(s): 9945e78

Create confidence.py

Browse files
Files changed (1) hide show
  1. confidence.py +40 -0
confidence.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
+ # Initialize FastAPI app
7
+ app = FastAPI(title="Confidence Statement API", version="1.0")
8
+
9
+ # Load the fine-tuned model and tokenizer
10
+ model_name = "mjpsm/Confidence-Statement-Model-final"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
+
14
+ # Define input format
15
+ class InputText(BaseModel):
16
+ statement: str
17
+
18
+ # Define prediction function
19
+ def predict_statement(statement: str):
20
+ inputs = tokenizer(statement, return_tensors="pt", padding=True, truncation=True, max_length=128)
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+ logits = outputs.logits
24
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
25
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
26
+ label_mapping = {0: "lack of self-confidence", 1: "self-confident"}
27
+ return {
28
+ "label": label_mapping[predicted_class],
29
+ "confidence_score": round(probabilities[0][predicted_class].item(), 4)
30
+ }
31
+
32
+ # Define root route
33
+ @app.get("/")
34
+ def read_root():
35
+ return {"message": "Welcome to the Confidence Statement API!"}
36
+
37
+ # Define prediction route
38
+ @app.post("/predict")
39
+ def predict(input_text: InputText):
40
+ return predict_statement(input_text.statement)