Sandei commited on
Commit
9c60f47
·
1 Parent(s): 3243c38

Fixed the urgency-related error

Browse files
__pycache__/app.cpython-314.pyc ADDED
Binary file (8.16 kB). View file
 
__pycache__/memory.cpython-314.pyc ADDED
Binary file (1.22 kB). View file
 
__pycache__/models.cpython-314.pyc ADDED
Binary file (2.52 kB). View file
 
__pycache__/multi_task_model_class.cpython-314.pyc ADDED
Binary file (1.84 kB). View file
 
__pycache__/rag.cpython-314.pyc ADDED
Binary file (1.29 kB). View file
 
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
 
2
  from fastapi import FastAPI
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
 
4
 
5
  from models import (
6
  QueryRequest,
@@ -8,49 +10,100 @@ from models import (
8
  CategoryPrediction,
9
  UrgencyPrediction
10
  )
 
11
  from rag import generate_answer
12
  from memory import get_conversation, add_message
13
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
 
 
 
17
 
18
- tag_classes = [
19
- "Billing",
20
- "Network & Connectivity",
21
- "Account Access",
22
- "Hardware",
23
- "Other"
24
- ]
25
 
26
  urgency_encoder = {
27
  0: "low",
28
  1: "medium",
29
- 2: "high"
 
30
  }
31
 
32
- tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID)
33
- config = AutoConfig.from_pretrained(CLASSIFIER_MODEL_ID)
34
 
35
- model = AutoModelForSequenceClassification.from_pretrained(
36
- CLASSIFIER_MODEL_ID,
37
- config=config,
38
- trust_remote_code=True
39
- ).to(DEVICE)
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  model.eval()
42
 
 
 
43
  app = FastAPI(title="RAG + Conversation Memory API")
44
 
45
  # ---------------------
46
  # CLASSIFIER
47
  # ---------------------
48
  def classify_text(text: str, threshold: float = 0.5):
49
- inputs = tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
 
 
 
50
 
51
  with torch.no_grad():
52
  outputs = model(**inputs)
53
 
 
54
  category_probs = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()
55
 
56
  categories = [
@@ -62,6 +115,7 @@ def classify_text(text: str, threshold: float = 0.5):
62
  if category_probs[i] >= threshold
63
  ]
64
 
 
65
  urgency_probs = torch.softmax(outputs.urgency_logits, dim=-1)[0].cpu().numpy()
66
  urgency_idx = int(torch.argmax(outputs.urgency_logits, dim=-1)[0])
67
 
@@ -74,6 +128,9 @@ def classify_text(text: str, threshold: float = 0.5):
74
 
75
 
76
  def retrieve_documents(query: str):
 
 
 
77
  return [
78
  "Restarting the router fixes most connectivity issues.",
79
  "Check for planned ISP maintenance.",
@@ -81,19 +138,32 @@ def retrieve_documents(query: str):
81
  ]
82
 
83
 
 
 
 
 
 
 
 
 
 
 
84
  @app.post("/query", response_model=QueryResponse)
85
  def query_endpoint(req: QueryRequest):
86
- # ---- Load conversation
 
 
 
87
  history = get_conversation(req.user_id)
88
 
89
- # ---- Classification
90
  categories, urgency = classify_text(req.query)
91
 
92
- # ---- RAG
93
  docs = retrieve_documents(req.query)
94
  answer = generate_answer(req.query, docs, history)
95
 
96
- # ---- Update memory
97
  add_message(req.user_id, "user", req.query)
98
  add_message(req.user_id, "assistant", answer)
99
 
@@ -105,3 +175,22 @@ def query_endpoint(req: QueryRequest):
105
  urgency=urgency,
106
  conversation=get_conversation(req.user_id)
107
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import os
3
  from fastapi import FastAPI
4
+ from transformers import AutoTokenizer
5
+ from huggingface_hub import hf_hub_download
6
 
7
  from models import (
8
  QueryRequest,
 
10
  CategoryPrediction,
11
  UrgencyPrediction
12
  )
13
+ from multi_task_model_class import MultiTaskModel
14
  from rag import generate_answer
15
  from memory import get_conversation, add_message
16
 
17
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face repo that holds the fine-tuned classifier checkpoint.
CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
# Backbone encoder the multi-task model was trained on.
ENCODER_NAME = "distilbert-base-uncased"

# Category label list; index i corresponds to category_logits[i].
# NOTE(review): 'Classroom/Lab Support' appears four times. len(tag_classes)
# (10) must match the trained category head, so the duplicates are kept
# as-is — but the label order/content should be verified against training.
tag_classes = ['Email & Communication', 'Classroom/Lab Support', 'Software & Applications', 'Classroom/Lab Support', 'Classroom/Lab Support', 'Network & Connectivity', 'General IT Support', 'Data Management', 'Classroom/Lab Support', 'Security & Compliance']

# Maps the urgency head's argmax index to a human-readable level.
urgency_encoder = {
    0: "low",
    1: "medium",
    2: "high",
    3: "critical"  # Added 4th level
}
31
 
32
print("Loading tokenizer...")
# trust_remote_code lets the hub repo ship a custom tokenizer class if needed.
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID, trust_remote_code=True)

print("Initializing model structure...")
# The architecture is rebuilt locally; only the weights come from the hub,
# so the head sizes here must match the checkpoint exactly.
model = MultiTaskModel(
    encoder_name=ENCODER_NAME,
    num_category_labels=len(tag_classes),
    num_urgency_labels=4
)


def _load_weights_into_model(path):
    """Load the checkpoint at *path* into the global ``model``.

    ``weights_only=True`` restricts unpickling to tensors/primitives, which
    is sufficient for a plain state_dict and safe against malicious pickle
    payloads in a downloaded file (unlike the previous weights_only=False).
    """
    state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)


# Load model weights
print("Downloading model weights...")
try:
    model_path = hf_hub_download(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
        token=None,  # Set to your HF token if repo is private
    )
    print(f"✓ Model downloaded to: {model_path}")

    print("Loading model weights...")
    _load_weights_into_model(model_path)
    print("✓ Model weights loaded successfully")

except Exception as e:
    print(f"✗ Error downloading from Hugging Face: {e}")
    print("\nTrying alternative methods...")

    # Method 2: the file may already be in the local HF cache even when the
    # network request failed.
    from huggingface_hub import try_to_load_from_cache
    cache_path = try_to_load_from_cache(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin"
    )

    if cache_path and os.path.exists(cache_path):
        print(f"✓ Found in cache: {cache_path}")
        _load_weights_into_model(cache_path)
        print("✓ Model loaded from cache")
    else:
        # No usable checkpoint anywhere: explain likely causes, then
        # re-raise the original download error.
        print("\n" + "="*60)
        print("ERROR: Could not load model weights")
        print("="*60)
        print("\nPossible solutions:")
        print("1. Login to Hugging Face:")
        print("   huggingface-cli login")
        print("\n2. Or download manually:")
        print(f"   Visit: https://huggingface.co/{CLASSIFIER_MODEL_ID}/tree/main")
        print(f"   Download 'pytorch_model.bin' to: ./Sandei/tech-support-classifier/")
        print("\n3. Check your internet connection")
        print("="*60)
        raise

model.to(DEVICE)
model.eval()  # inference only — disable dropout etc.

print(f"\n✓ Model ready on {DEVICE}\n")
91
+
92
  app = FastAPI(title="RAG + Conversation Memory API")
93
 
94
  # ---------------------
95
  # CLASSIFIER
96
  # ---------------------
97
  def classify_text(text: str, threshold: float = 0.5):
98
+ """
99
+ Classify input text into categories and urgency level.
100
+ """
101
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
102
 
103
  with torch.no_grad():
104
  outputs = model(**inputs)
105
 
106
+ # Category predictions (multi-label)
107
  category_probs = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()
108
 
109
  categories = [
 
115
  if category_probs[i] >= threshold
116
  ]
117
 
118
+ # Urgency prediction (multi-class)
119
  urgency_probs = torch.softmax(outputs.urgency_logits, dim=-1)[0].cpu().numpy()
120
  urgency_idx = int(torch.argmax(outputs.urgency_logits, dim=-1)[0])
121
 
 
128
 
129
 
130
  def retrieve_documents(query: str):
131
+ """
132
+ Retrieve relevant documents for RAG.
133
+ """
134
  return [
135
  "Restarting the router fixes most connectivity issues.",
136
  "Check for planned ISP maintenance.",
 
138
  ]
139
 
140
 
141
@app.get("/")
def root():
    """Health check: report liveness plus the active device and model id."""
    payload = {"status": "running"}
    payload["device"] = DEVICE
    payload["model"] = CLASSIFIER_MODEL_ID
    return payload
149
+
150
+
151
  @app.post("/query", response_model=QueryResponse)
152
  def query_endpoint(req: QueryRequest):
153
+ """
154
+ Main query endpoint.
155
+ """
156
+ # Load conversation history
157
  history = get_conversation(req.user_id)
158
 
159
+ # Classification
160
  categories, urgency = classify_text(req.query)
161
 
162
+ # RAG
163
  docs = retrieve_documents(req.query)
164
  answer = generate_answer(req.query, docs, history)
165
 
166
+ # Update conversation memory
167
  add_message(req.user_id, "user", req.query)
168
  add_message(req.user_id, "assistant", answer)
169
 
 
175
  urgency=urgency,
176
  conversation=get_conversation(req.user_id)
177
  )
178
+
179
+
180
@app.post("/classify")
def classify_endpoint(req: QueryRequest):
    """Classify a query without touching RAG or conversation memory."""
    text = req.query
    predicted_categories, predicted_urgency = classify_text(text)
    return {
        "query": text,
        "categories": predicted_categories,
        "urgency": predicted_urgency,
    }
192
+
193
+
194
def _serve() -> None:
    """Run a local development server (only used when executed directly)."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()
multi_task_model_class.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import AutoModel
3
+
4
+
5
class MultiTaskModel(nn.Module):
    """Shared text encoder with two classification heads.

    ``category_classifier`` produces multi-label category logits and
    ``urgency_classifier`` produces multi-class urgency logits. The head
    attribute names must match the keys in the published checkpoint's
    state_dict.
    """

    def __init__(self, encoder_name, num_category_labels, num_urgency_labels):
        """
        Args:
            encoder_name: Hugging Face model id of the backbone encoder.
            num_category_labels: output size of the category head.
            num_urgency_labels: output size of the urgency head.
        """
        super().__init__()

        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size

        # Changed from category_head to category_classifier
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        # Changed from urgency_head to urgency_classifier
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(self, input_ids, attention_mask):
        """Return an object exposing ``category_logits`` and ``urgency_logits``."""
        from types import SimpleNamespace  # local import: module deps unchanged

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Pool by taking the first token's final hidden state
        # (the [CLS]-style position for BERT-family encoders).
        pooled = outputs.last_hidden_state[:, 0]

        # SimpleNamespace replaces the previous `type("Output", (), {...})()`
        # trick, which built a brand-new throwaway class on every forward
        # pass; attribute access by callers is identical.
        return SimpleNamespace(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )