# app.py — IMDb sentiment-analysis demo (commit 881173b)
import os

# Fix Hugging Face cache permission issues on hosted runtimes.
# IMPORTANT: these must be set BEFORE importing `transformers` —
# transformers/huggingface_hub read TRANSFORMERS_CACHE and HF_HOME at
# import time to resolve cache paths, so setting them afterwards has
# no effect on where models are cached.
os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/tmp/huggingface")

from flask import Flask, request, render_template, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
app = Flask(__name__)
# RoBERTa model fine-tuned on IMDb (binary sentiment: 0=Negative, 1=Positive)
MODEL_ID = "textattack/roberta-base-imdb"
# Load tokenizer & model once at startup (downloaded on first run, then served
# from the cache directory configured above).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
# Inference only: switch off dropout / training-mode behavior.
model.eval()
def predict(text: str):
    """Run sentiment inference on *text*.

    Returns a dict with ``label`` ("Negative"/"Positive") and
    ``confidence`` (softmax probability rounded to three decimals).
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoded).logits
    scores = torch.nn.functional.softmax(logits, dim=1)
    best = int(scores.argmax(dim=1).item())
    best_score = float(scores[0][best].item())
    # Index -> human-readable label; "Neutral" is an unreachable fallback
    # for a binary head, kept for safety.
    names = {0: "Negative", 1: "Positive"}
    return {"label": names.get(best, "Neutral"), "confidence": round(best_score, 3)}
@app.route("/", methods=["GET"])
def index():
    """Render the landing page with the analysis form."""
    template = "index.html"
    return render_template(template)
@app.route("/predict", methods=["POST"])
def predict_route():
    """Handle the HTML form submission and render the verdict inline."""
    text = request.form.get("text", "").strip()
    if not text:
        # Nothing to analyze — re-render the form with a prompt.
        return render_template("index.html", result="Please enter text to analyze.", input_text="")
    outcome = predict(text)
    summary = f"{outcome['label']} (conf: {outcome['confidence']})"
    return render_template("index.html", result=summary, input_text=text)
@app.route("/api/predict", methods=["POST"])
def api_predict():
    """JSON API endpoint: accepts ``{"text": "..."}``, returns the prediction.

    Responds 400 (as JSON) when the body is malformed JSON, is not a JSON
    object, or when "text" is missing, empty, or not a string. The original
    code crashed with a 500 on non-dict bodies (e.g. a bare JSON list) and
    returned an HTML error page on malformed JSON.
    """
    # silent=True yields None instead of raising on malformed JSON,
    # so we can answer with a consistent JSON error body.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    text = data.get("text", "")
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "No text provided"}), 400
    result = predict(text)
    return jsonify(result)
if __name__ == "__main__":
    # Hosted runtimes (e.g. HF Spaces) inject PORT; default to 7860 locally.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port)