|
|
import os |
|
|
from flask import Flask, request, render_template, jsonify |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
|
|
|
# Default the Hugging Face cache locations to /tmp unless the environment
# already provides them; an existing value is always left untouched.
# NOTE(review): these are assigned after `transformers` is imported at the top
# of the file — confirm the library still picks them up at this point.
for _env_name, _fallback in (
    ("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers"),
    ("HF_HOME", "/tmp/huggingface"),
):
    os.environ.setdefault(_env_name, _fallback)
|
|
|
|
|
# Flask application object; the @app.route decorators further down register
# their view functions on it, and the __main__ guard runs it.
app = Flask(__name__)
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the sequence-classification checkpoint
# loaded by this app.
MODEL_ID = "textattack/roberta-base-imdb"

# Pass the cache directory explicitly instead of relying only on environment
# variables: in this file they are assigned after `transformers` has already
# been imported, so the library may have resolved its default cache path
# before they were set. The fallback mirrors the default used for
# TRANSFORMERS_CACHE earlier in the file.
_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")

# Loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, cache_dir=_CACHE_DIR
)

# Put the model in evaluation mode; it is only used for inference here.
model.eval()
|
|
|
|
|
def predict(text: str):
    """Classify ``text`` with the module-level tokenizer and model.

    Returns:
        A dict with two keys: ``"label"`` is ``"Negative"`` for class
        index 0, ``"Positive"`` for index 1 and ``"Neutral"`` for any other
        index; ``"confidence"`` is the softmax probability of the chosen
        class, rounded to three decimal places.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**encoded).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    best_idx = int(probabilities.argmax(dim=1).item())
    score = float(probabilities[0][best_idx].item())

    names_by_index = {0: "Negative", 1: "Positive"}
    chosen_label = names_by_index.get(best_idx, "Neutral")
    return {"label": chosen_label, "confidence": round(score, 3)}
|
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Render ``index.html`` with no result or input text supplied."""
    page = render_template("index.html")
    return page
|
|
|
|
|
@app.route("/predict", methods=["POST"])
def predict_route():
    """Handle the HTML form submission and re-render ``index.html``.

    Reads the ``text`` form field, strips surrounding whitespace, and renders
    the page with either the prediction summary or a prompt asking for input
    when the field is empty.
    """
    submitted = request.form.get("text", "").strip()

    if submitted:
        outcome = predict(submitted)
        summary = f"{outcome['label']} (conf: {outcome['confidence']})"
        return render_template("index.html", result=summary, input_text=submitted)

    # Nothing usable was submitted: show the prompt and clear the input box.
    return render_template(
        "index.html",
        result="Please enter text to analyze.",
        input_text="",
    )
|
|
|
|
|
@app.route("/api/predict", methods=["POST"])
def api_predict():
    """JSON endpoint: classify the ``text`` field of the request body.

    Expects a JSON object such as ``{"text": "..."}`` (the Content-Type
    header is not required because parsing is forced).

    Returns:
        On success, the JSON-encoded dict produced by ``predict``.
        On bad input, a JSON ``{"error": ...}`` body with HTTP status 400:
        when the body is not valid JSON or not a JSON object, when ``text``
        is not a string, or when ``text`` is missing/empty/whitespace-only.
    """
    # silent=True makes a malformed body come back as None instead of raising,
    # so the client gets a JSON error rather than Flask's default error page.
    payload = request.get_json(force=True, silent=True)

    # Valid JSON can still be a list, number, string or null; none of those
    # have a "text" field to read.
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = payload.get("text", "")
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400

    # Strip like the HTML form route does, so whitespace-only input is
    # rejected consistently by both endpoints.
    text = text.strip()
    if not text:
        return jsonify({"error": "No text provided"}), 400

    return jsonify(predict(text))
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is not set.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|
|
|