Pulastya0 commited on
Commit
35e526a
·
verified ·
1 Parent(s): dcbee48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -4,11 +4,17 @@ from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
 
7
- app = FastAPI(title="Routing Service - Space 2")
8
-
 
9
  os.environ["HF_HOME"] = "/data/huggingface-cache"
10
  os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface-cache"
11
 
 
 
 
 
 
12
  # -------------------------------
13
  # Request Model
14
  # -------------------------------
@@ -16,16 +22,14 @@ class RoutingRequest(BaseModel):
16
  text: str
17
 
18
  # -------------------------------
19
- # Load Routing Model (DeBERTa MNLI)
20
  # -------------------------------
21
  MODEL_NAME = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
22
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
24
 
25
- # Define your possible departments / labels
26
- DEPARTMENTS = ["Account", "Software", "Network", "Security", "Hardware",
27
- "Infrastructure", "Licensing", "Communication", "RemoteWork",
28
- "Training", "Performance"]
29
 
30
  # -------------------------------
31
  # Routing Endpoint
@@ -36,13 +40,11 @@ async def route_ticket(req: RoutingRequest):
36
  if not text:
37
  raise HTTPException(status_code=400, detail="Text cannot be empty")
38
 
39
- # Tokenize
40
  inputs = tokenizer(text, return_tensors="pt", truncation=True)
41
  outputs = model(**inputs)
42
  logits = outputs.logits[0]
43
 
44
- # Simple mapping: choose max logit index as department (demo)
45
- # For a real hackathon, you may map labels more carefully
46
  department_idx = torch.argmax(logits).item() % len(DEPARTMENTS)
47
  department = DEPARTMENTS[department_idx]
48
 
 
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import torch
6
 
7
+ # -------------------------------
8
+ # Set Hugging Face cache to writable directory
9
+ # -------------------------------
10
  os.environ["HF_HOME"] = "/data/huggingface-cache"
11
  os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface-cache"
12
 
13
+ # -------------------------------
14
+ # FastAPI app
15
+ # -------------------------------
16
+ app = FastAPI(title="Routing Service - Space 2")
17
+
18
  # -------------------------------
19
  # Request Model
20
  # -------------------------------
 
22
  text: str
23
 
24
  # -------------------------------
25
+ # Load DeBERTa MNLI Model
26
  # -------------------------------
27
  MODEL_NAME = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
28
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
29
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
30
 
31
+ # Departments mapping (example, can adjust for hackathon)
32
+ DEPARTMENTS = ["Networking", "Hardware", "Software", "Security", "General IT"]
 
 
33
 
34
  # -------------------------------
35
  # Routing Endpoint
 
40
  if not text:
41
  raise HTTPException(status_code=400, detail="Text cannot be empty")
42
 
 
43
  inputs = tokenizer(text, return_tensors="pt", truncation=True)
44
  outputs = model(**inputs)
45
  logits = outputs.logits[0]
46
 
47
+ # Simple mapping: max logit department
 
48
  department_idx = torch.argmax(logits).item() % len(DEPARTMENTS)
49
  department = DEPARTMENTS[department_idx]
50