yazied49 commited on
Commit
7fac42a
·
verified ·
1 Parent(s): 49031d5

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -0
main.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+ import json
6
+
7
+ app = FastAPI()
8
+
9
+ # تحميل الموديل والتوكنايزر
10
+ model_path = "./needs_model"
11
+
12
+
13
+ model = BertForSequenceClassification.from_pretrained(model_path)
14
+ tokenizer = BertTokenizer.from_pretrained(model_path)
15
+
16
+ # تحميل ماب الـ labels
17
+ with open(f"{model_path}/id2label.json", "r", encoding="utf-8") as f:
18
+ id2label = json.load(f)
19
+
20
+ # نموذج البيانات اللي جايه من الباك إند
21
+ class TextInput(BaseModel):
22
+ text: str
23
+
24
+ @app.post("/predict")
25
+ def predict(input: TextInput):
26
+ inputs = tokenizer(input.text, return_tensors="pt", truncation=True, padding=True)
27
+ with torch.no_grad():
28
+ outputs = model(**inputs)
29
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
+ predicted_id = torch.argmax(probs).item()
31
+
32
+ label_info = id2label[str(predicted_id)]
33
+ return {
34
+ "category": label_info["category"],
35
+ "sub_category": label_info["sub_category"],
36
+ "confidence": float(probs[0][predicted_id])
37
+ }