mjpsm commited on
Commit
d379983
·
verified ·
1 Parent(s): 35a8807

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -1,8 +1,14 @@
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  import torch
5
 
 
 
 
 
 
6
  # Initialize FastAPI
7
  app = FastAPI(title="Check-ins Classifier API", version="1.0")
8
 
@@ -12,7 +18,7 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
13
  model.eval()
14
 
15
- # Define label mapping
16
  id2label = {
17
  0: "Bad",
18
  1: "Mediocre",
@@ -25,16 +31,12 @@ class InputText(BaseModel):
25
 
26
  @app.post("/predict")
27
  async def predict(data: InputText):
28
- # Tokenize input
29
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)
30
-
31
- # Model inference
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
35
  predicted_label_id = torch.argmax(probs, dim=-1).item()
36
-
37
- # Return JSON response
38
  return {
39
  "input_text": data.text,
40
  "predicted_label": id2label[predicted_label_id],
@@ -45,3 +47,4 @@ async def predict(data: InputText):
45
  @app.get("/")
46
  async def home():
47
  return {"message": "Welcome to the Check-ins Classifier API. Use POST /predict to classify text."}
 
 
1
+ import os
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  import torch
6
 
7
+ # ✅ Fix Hugging Face cache permissions
8
+ os.environ["HF_HOME"] = "/app/hf_cache"
9
+ os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
10
+ os.makedirs("/app/hf_cache", exist_ok=True)
11
+
12
  # Initialize FastAPI
13
  app = FastAPI(title="Check-ins Classifier API", version="1.0")
14
 
 
18
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
19
  model.eval()
20
 
21
+ # Label mapping
22
  id2label = {
23
  0: "Bad",
24
  1: "Mediocre",
 
31
 
32
  @app.post("/predict")
33
  async def predict(data: InputText):
 
34
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)
 
 
35
  with torch.no_grad():
36
  outputs = model(**inputs)
37
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
38
  predicted_label_id = torch.argmax(probs, dim=-1).item()
39
+
 
40
  return {
41
  "input_text": data.text,
42
  "predicted_label": id2label[predicted_label_id],
 
47
  @app.get("/")
48
  async def home():
49
  return {"message": "Welcome to the Check-ins Classifier API. Use POST /predict to classify text."}
50
+