"""Flask web app serving binary sentiment analysis with a RoBERTa model
fine-tuned on IMDb. Exposes an HTML form ("/", "/predict") and a JSON API
("/api/predict")."""

import os

# Hugging Face reads these cache locations at *import* time, so they must be
# set BEFORE `transformers` is imported. (Previously they were assigned after
# the import, which had no effect on the already-loaded library.)
# NOTE: TRANSFORMERS_CACHE is deprecated in newer transformers releases in
# favor of HF_HOME; both are set for compatibility.
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("HF_HOME", "/tmp/huggingface")

import torch
from flask import Flask, jsonify, render_template, request
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = Flask(__name__)

# RoBERTa model fine-tuned on IMDb movie reviews (binary sentiment).
MODEL_ID = "textattack/roberta-base-imdb"

# Load tokenizer & model once at startup; inference-only, so eval mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()

# Class-index-to-name mapping for this checkpoint (0 = negative, 1 = positive).
# Hoisted to a module constant instead of rebuilding the dict on every call.
LABEL_MAP = {0: "Negative", 1: "Positive"}


def predict(text: str) -> dict:
    """Classify *text* and return ``{"label": str, "confidence": float}``.

    Confidence is the softmax probability of the winning class, rounded
    to 3 decimal places.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    label_idx = int(torch.argmax(probs, dim=1).item())
    confidence = float(probs[0][label_idx].item())
    return {
        "label": LABEL_MAP.get(label_idx, "Neutral"),
        "confidence": round(confidence, 3),
    }


@app.route("/", methods=["GET"])
def index():
    """Serve the HTML input form."""
    return render_template("index.html")


@app.route("/predict", methods=["POST"])
def predict_route():
    """Handle the HTML form submission and re-render the page with the result."""
    text = request.form.get("text", "").strip()
    if not text:
        return render_template(
            "index.html", result="Please enter text to analyze.", input_text=""
        )
    result = predict(text)
    return render_template(
        "index.html",
        result=f"{result['label']} (conf: {result['confidence']})",
        input_text=text,
    )


@app.route("/api/predict", methods=["POST"])
def api_predict():
    """JSON API: accepts ``{"text": ...}``, returns label + confidence.

    Returns 400 (JSON) when the body is missing, malformed, or ``text``
    is empty or not a string.
    """
    # silent=True: malformed/absent JSON yields None instead of werkzeug's
    # HTML 400 error page, so this endpoint always answers in JSON.
    data = request.get_json(force=True, silent=True) or {}
    text = data.get("text", "")
    # Reject non-strings too — the tokenizer would otherwise raise a 500.
    if not text or not isinstance(text, str):
        return jsonify({"error": "No text provided"}), 400
    result = predict(text)
    return jsonify(result)


if __name__ == "__main__":
    # PORT is injected by most hosted runtimes; default matches HF Spaces.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)