dien2112 commited on
Commit
2178543
·
verified ·
1 Parent(s): 337df3c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +59 -59
main.py CHANGED
@@ -10,81 +10,81 @@ model_name = "ProsusAI/finbert"
10
  print(f"Loading model {model_name}...")
11
 
12
  try:
13
-     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
14
-     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
15
-     cryptobert = TextClassificationPipeline(
16
-         model=model, 
17
-         tokenizer=tokenizer, 
18
-         max_length=64, 
19
-         truncation=True, 
20
-         padding='max_length'
21
-     )
22
-     print("Model loaded successfully!")
23
  except Exception as e:
24
-     print(f"Error loading model: {e}")
25
-     cryptobert = None
26
 
27
  # --------- Định nghĩa Schema ---------
28
  class AnalyzeRequest(BaseModel):
29
-     texts: list[str]
30
 
31
  class AnalyzeResult(BaseModel):
32
-     text: str
33
-     label: str
34
-     score: float
35
-     numeric_score: float
36
 
37
  class AnalyzeResponse(BaseModel):
38
-     results: list[AnalyzeResult]
39
-     avg_score: float
40
 
41
  # --------- Helper Function ---------
42
  def calculate_numeric_score(label: str, score: float) -> float:
43
-     if label == 'positive':
44
-         return score
45
-     elif label == 'negative':
46
-         return -score
47
-     else: # Neutral
48
-         return 0.0
49
 
50
  # --------- API Endpoints ---------
51
  @app.get("/")
52
  def read_root():
53
-     return {"status": "ok", "message": "CryptoBERT API is running", "model_loaded": cryptobert is not None}
54
 
55
  @app.post("/api/sentiment", response_model=AnalyzeResponse)
56
  def analyze_sentiment(req: AnalyzeRequest):
57
-     if not cryptobert:
58
-         raise HTTPException(status_code=500, detail="Model is not loaded properly.")
59
-     
60
-     if not req.texts:
61
-         return {"results": [], "avg_score": 0.0}
62
 
63
-     try:
64
-         # Run predictions in batch
65
-         preds = cryptobert(req.texts)
66
-         
67
-         results = []
68
-         total_numeric = 0.0
69
-         
70
-         for text, pred in zip(req.texts, preds):
71
-             label = pred['label']
72
-             score = float(pred['score'])
73
-             numeric_score = calculate_numeric_score(label, score)
74
-             
75
-             results.append({
76
-                 "text": text,
77
-                 "label": label,
78
-                 "score": score,
79
-                 "numeric_score": numeric_score
80
-             })
81
-             total_numeric += numeric_score
82
-             
83
-         avg_score = total_numeric / len(results) if len(results) > 0 else 0.0
84
-         
85
-         return {
86
-             "results": results,
87
-             "avg_score": avg_score
88
-         }
89
-     except Exception as e:
90
-         raise HTTPException(status_code=500, detail=str(e))
 
10
  print(f"Loading model {model_name}...")
11
 
12
  try:
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
14
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
15
+ cryptobert = TextClassificationPipeline(
16
+ model=model,
17
+ tokenizer=tokenizer,
18
+ max_length=64,
19
+ truncation=True,
20
+ padding='max_length'
21
+ )
22
+ print("Model loaded successfully!")
23
  except Exception as e:
24
+ print(f"Error loading model: {e}")
25
+ cryptobert = None
26
 
27
  # --------- Định nghĩa Schema ---------
28
  class AnalyzeRequest(BaseModel):
29
+ texts: list[str]
30
 
31
  class AnalyzeResult(BaseModel):
32
+ text: str
33
+ label: str
34
+ score: float
35
+ numeric_score: float
36
 
37
  class AnalyzeResponse(BaseModel):
38
+ results: list[AnalyzeResult]
39
+ avg_score: float
40
 
41
  # --------- Helper Function ---------
42
  def calculate_numeric_score(label: str, score: float) -> float:
43
+ if label == 'positive':
44
+ return score
45
+ elif label == 'negative':
46
+ return -score
47
+ else: # Neutral
48
+ return 0.0
49
 
50
  # --------- API Endpoints ---------
51
  @app.get("/")
52
  def read_root():
53
+ return {"status": "ok", "message": "FinBERT API is running", "model_loaded": cryptobert is not None}
54
 
55
  @app.post("/api/sentiment", response_model=AnalyzeResponse)
56
  def analyze_sentiment(req: AnalyzeRequest):
57
+ if not cryptobert:
58
+ raise HTTPException(status_code=500, detail="Model is not loaded properly.")
59
+
60
+ if not req.texts:
61
+ return {"results": [], "avg_score": 0.0}
62
 
63
+ try:
64
+ # Run predictions in batch
65
+ preds = cryptobert(req.texts)
66
+
67
+ results = []
68
+ total_numeric = 0.0
69
+
70
+ for text, pred in zip(req.texts, preds):
71
+ label = pred['label']
72
+ score = float(pred['score'])
73
+ numeric_score = calculate_numeric_score(label, score)
74
+
75
+ results.append({
76
+ "text": text,
77
+ "label": label,
78
+ "score": score,
79
+ "numeric_score": numeric_score
80
+ })
81
+ total_numeric += numeric_score
82
+
83
+ avg_score = total_numeric / len(results) if len(results) > 0 else 0.0
84
+
85
+ return {
86
+ "results": results,
87
+ "avg_score": avg_score
88
+ }
89
+ except Exception as e:
90
+ raise HTTPException(status_code=500, detail=str(e))