devil2810 committed on
Commit
881173b
·
verified ·
1 Parent(s): dbb32c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -1,20 +1,18 @@
1
-
2
  import os
3
- # Fix Hugging Face cache permission issues on hosted runtimes
4
- os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
5
- os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/tmp/huggingface")
6
-
7
  from flask import Flask, request, render_template, jsonify
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import torch
10
 
 
 
 
 
11
  app = Flask(__name__)
12
 
13
- # Use a RoBERTa model fine-tuned on IMDb (public Hugging Face model).
14
- # Model choice: textattack/roberta-base-imdb (widely used fine-tuned checkpoint)
15
  MODEL_ID = "textattack/roberta-base-imdb"
16
 
17
- # Load tokenizer & model once at startup
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
20
  model.eval()
@@ -24,11 +22,10 @@ def predict(text: str):
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
27
- label = int(torch.argmax(probs, dim=1).item())
28
- confidence = float(probs[0][label].item())
29
- # IMDb fine-tuned label mapping: 1 => Positive, 0 => Negative
30
  label_map = {0: "Negative", 1: "Positive"}
31
- return {"label": label_map.get(label, "Neutral"), "confidence": round(confidence, 3)}
32
 
33
  @app.route("/", methods=["GET"])
34
  def index():
@@ -47,7 +44,7 @@ def api_predict():
47
  data = request.get_json(force=True)
48
  text = data.get("text", "")
49
  if not text:
50
- return jsonify({"error":"No text provided"}), 400
51
  result = predict(text)
52
  return jsonify(result)
53
 
 
 
1
  import os
 
 
 
 
2
  from flask import Flask, request, render_template, jsonify
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
 
6
+ # Fix Hugging Face cache permission issues on hosted runtimes
7
+ os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
8
+ os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/tmp/huggingface")
9
+
10
  app = Flask(__name__)
11
 
12
+ # RoBERTa model fine-tuned on IMDb
 
13
  MODEL_ID = "textattack/roberta-base-imdb"
14
 
15
+ # Load tokenizer & model at startup
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
17
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
18
  model.eval()
 
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
25
+ label_idx = int(torch.argmax(probs, dim=1).item())
26
+ confidence = float(probs[0][label_idx].item())
 
27
  label_map = {0: "Negative", 1: "Positive"}
28
+ return {"label": label_map.get(label_idx, "Neutral"), "confidence": round(confidence, 3)}
29
 
30
  @app.route("/", methods=["GET"])
31
  def index():
 
44
  data = request.get_json(force=True)
45
  text = data.get("text", "")
46
  if not text:
47
+ return jsonify({"error": "No text provided"}), 400
48
  result = predict(text)
49
  return jsonify(result)
50