yuu1234 committed on
Commit
bb1c463
·
1 Parent(s): db82836
Files changed (3) hide show
  1. Dockerfile +4 -9
  2. app.py +21 -27
  3. requirements.txt +3 -2
Dockerfile CHANGED
@@ -1,21 +1,16 @@
1
- # Base image
2
  FROM python:3.10-slim
3
-
4
- # Set working directory
5
  WORKDIR /app
6
 
7
- # Copy requirements and install
8
  COPY requirements.txt .
9
  RUN pip install --no-cache-dir -r requirements.txt
10
 
11
- # Copy app code and model
12
  COPY app.py .
13
  COPY model_save/ ./model_save/
14
  COPY best_model.pt .
15
 
16
- # Expose ports
17
  EXPOSE 7860
18
- EXPOSE 5000
19
 
20
- # Start Gradio UI + Flask API
21
- CMD bash -c "python app.py & gunicorn -w 4 -b 0.0.0.0:5000 app:app"
 
 
1
  FROM python:3.10-slim
 
 
2
  WORKDIR /app
3
 
4
+ # Install dependencies
5
  COPY requirements.txt .
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
8
+ # Copy code & model
9
  COPY app.py .
10
  COPY model_save/ ./model_save/
11
  COPY best_model.pt .
12
 
 
13
  EXPOSE 7860
 
14
 
15
+ # Run app
16
+ CMD ["python", "app.py"]
app.py CHANGED
@@ -1,27 +1,26 @@
1
  import gradio as gr
2
- from flask import Flask, request, jsonify
3
  import torch
4
  from transformers import BertTokenizer, BertForSequenceClassification
 
 
 
5
  import threading
6
 
7
  # ------------------- Device -------------------
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- # ------------------- Load model -------------------
11
  tokenizer = BertTokenizer.from_pretrained("./model_save")
12
  model = BertForSequenceClassification.from_pretrained("./model_save")
13
  model.load_state_dict(torch.load("best_model.pt", map_location=device))
14
  model.to(device)
15
  model.eval()
16
 
17
- # ------------------- Prediction function -------------------
18
- def predict_offensive(text):
19
  encoded = tokenizer(
20
- text,
21
- return_tensors="pt",
22
- truncation=True,
23
- padding="max_length",
24
- max_length=128
25
  )
26
  input_ids = encoded["input_ids"].to(device)
27
  attention_mask = encoded["attention_mask"].to(device)
@@ -32,36 +31,31 @@ def predict_offensive(text):
32
  pred = torch.argmax(logits, dim=1).item()
33
  return "Offensive" if pred == 1 else "Not Offensive"
34
 
35
- # ------------------- Flask API -------------------
36
- app = Flask(__name__)
37
-
38
- @app.route("/predict", methods=["POST"])
39
- def api_predict():
40
- data = request.json
41
- if not data or "text" not in data:
42
- return jsonify({"error": "Missing 'text' field"}), 400
43
-
44
- text = data["text"]
45
- prediction = predict_offensive(text)
46
- return jsonify({"prediction": prediction})
47
-
48
  # ------------------- Gradio UI -------------------
49
  iface = gr.Interface(
50
  fn=predict_offensive,
51
  inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
52
  outputs="text",
53
  title="Offensive Language Detector",
54
- description="Enter a sentence and the model predicts if it contains offensive language."
55
  )
56
 
57
  def run_gradio():
58
  iface.launch(server_name="0.0.0.0", server_port=7860, share=False, prevent_thread_lock=True)
59
 
 
 
 
 
 
 
 
 
 
 
60
  # ------------------- Main -------------------
61
  if __name__ == "__main__":
62
  # Start Gradio in a separate thread
63
  threading.Thread(target=run_gradio).start()
64
-
65
- # Flask API will be served by Gunicorn (HF Spaces sẽ build chạy)
66
- # gunicorn -w 4 -b 0.0.0.0:5000 app:app
67
- print("Flask API ready. Use Gunicorn to serve for concurrent requests.")
 
1
  import gradio as gr
 
2
  import torch
3
  from transformers import BertTokenizer, BertForSequenceClassification
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ import uvicorn
7
  import threading
8
 
9
  # ------------------- Device -------------------
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # ------------------- Load Model -------------------
13
  tokenizer = BertTokenizer.from_pretrained("./model_save")
14
  model = BertForSequenceClassification.from_pretrained("./model_save")
15
  model.load_state_dict(torch.load("best_model.pt", map_location=device))
16
  model.to(device)
17
  model.eval()
18
 
19
+ # ------------------- Prediction Function -------------------
20
+ def predict_offensive(text: str):
21
  encoded = tokenizer(
22
+ text, return_tensors="pt",
23
+ truncation=True, padding="max_length", max_length=128
 
 
 
24
  )
25
  input_ids = encoded["input_ids"].to(device)
26
  attention_mask = encoded["attention_mask"].to(device)
 
31
  pred = torch.argmax(logits, dim=1).item()
32
  return "Offensive" if pred == 1 else "Not Offensive"
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # ------------------- Gradio UI -------------------
35
  iface = gr.Interface(
36
  fn=predict_offensive,
37
  inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
38
  outputs="text",
39
  title="Offensive Language Detector",
 
40
  )
41
 
42
  def run_gradio():
43
  iface.launch(server_name="0.0.0.0", server_port=7860, share=False, prevent_thread_lock=True)
44
 
45
+ # ------------------- FastAPI -------------------
46
+ app = FastAPI(title="Offensive Language API")
47
+
48
+ class TextItem(BaseModel):
49
+ text: str
50
+
51
+ @app.post("/predict")
52
+ def api_predict(item: TextItem):
53
+ return {"prediction": predict_offensive(item.text)}
54
+
55
  # ------------------- Main -------------------
56
  if __name__ == "__main__":
57
  # Start Gradio in a separate thread
58
  threading.Thread(target=run_gradio).start()
59
+
60
+ # Run FastAPI (HF Spaces will expose /docs automatically)
61
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  torch
2
  transformers
3
- flask
4
- gunicorn
5
  gradio
 
 
6
  numpy
 
 
1
  torch
2
  transformers
 
 
3
  gradio
4
+ fastapi
5
+ uvicorn
6
  numpy
7
+ pydantic