File size: 1,995 Bytes
8205a80
 
 
 
 
881173b
 
 
 
8205a80
 
881173b
8205a80
 
881173b
8205a80
 
 
 
 
 
 
 
 
881173b
 
8205a80
881173b
8205a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
881173b
8205a80
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os

# Fix Hugging Face cache permission issues on hosted runtimes.
# NOTE: these MUST be set before `transformers` is imported — the library
# resolves its cache directories at import time, so setting them afterwards
# (as the previous version did) has no effect.
os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/tmp/huggingface")

from flask import Flask, request, render_template, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = Flask(__name__)

# RoBERTa model fine-tuned on IMDb (binary sentiment: 0 = negative, 1 = positive).
MODEL_ID = "textattack/roberta-base-imdb"

# Load tokenizer & model once at startup so requests don't pay the load cost.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout etc.

def predict(text: str) -> dict:
    """Classify *text* with the IMDb sentiment model.

    Returns a dict with a human-readable ``label`` ("Negative"/"Positive")
    and a ``confidence`` float rounded to 3 decimal places.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        logits = model(**encoded).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        best_idx = int(torch.argmax(probabilities, dim=1).item())
        score = float(probabilities[0][best_idx].item())
    # Map the model's class index to a display label; "Neutral" is an
    # unreachable fallback for this binary model.
    names = {0: "Negative", 1: "Positive"}
    return {"label": names.get(best_idx, "Neutral"), "confidence": round(score, 3)}

@app.route("/", methods=["GET"])
def index():
    """Serve the HTML landing page with the input form."""
    return render_template("index.html")

@app.route("/predict", methods=["POST"])
def predict_route():
    """Handle the HTML form submission and re-render the page with a result."""
    submitted = request.form.get("text", "").strip()
    if not submitted:
        # Blank / whitespace-only input: prompt the user instead of predicting.
        return render_template("index.html", result="Please enter text to analyze.", input_text="")
    outcome = predict(submitted)
    return render_template(
        "index.html",
        result=f"{outcome['label']} (conf: {outcome['confidence']})",
        input_text=submitted,
    )

@app.route("/api/predict", methods=["POST"])
def api_predict():
    """JSON API: POST {"text": "..."} -> {"label": ..., "confidence": ...}.

    Returns a 400 JSON error for missing, malformed, or non-string input.
    The previous version used ``get_json(force=True)``, which raises
    ``BadRequest`` on malformed JSON (serving an HTML error page from a JSON
    API), and crashed with a 500 when the body was valid JSON but not an
    object (e.g. a list) or when ``text`` was a non-string (e.g. a number).
    """
    # silent=True yields None on unparseable bodies instead of raising.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "No text provided"}), 400
    text = data.get("text", "")
    # Reject non-strings and whitespace-only input (consistent with the form route).
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "No text provided"}), 400
    result = predict(text)
    return jsonify(result)

if __name__ == "__main__":
    # Respect the platform-assigned PORT; default to 7860 (Hugging Face Spaces).
    # Bind to all interfaces so containerized hosts can reach the server.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port)