Fixed the urgency-related error
Browse files
__pycache__/app.cpython-314.pyc
ADDED
|
Binary file (8.16 kB). View file
|
|
|
__pycache__/memory.cpython-314.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
__pycache__/multi_task_model_class.cpython-314.pyc
ADDED
|
Binary file (1.84 kB). View file
|
|
|
__pycache__/rag.cpython-314.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from fastapi import FastAPI
|
| 3 |
-
from transformers import AutoTokenizer
|
|
|
|
| 4 |
|
| 5 |
from models import (
|
| 6 |
QueryRequest,
|
|
@@ -8,49 +10,100 @@ from models import (
|
|
| 8 |
CategoryPrediction,
|
| 9 |
UrgencyPrediction
|
| 10 |
)
|
|
|
|
| 11 |
from rag import generate_answer
|
| 12 |
from memory import get_conversation, add_message
|
| 13 |
|
| 14 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
|
| 16 |
CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
tag_classes = [
|
| 19 |
-
"Billing",
|
| 20 |
-
"Network & Connectivity",
|
| 21 |
-
"Account Access",
|
| 22 |
-
"Hardware",
|
| 23 |
-
"Other"
|
| 24 |
-
]
|
| 25 |
|
| 26 |
urgency_encoder = {
|
| 27 |
0: "low",
|
| 28 |
1: "medium",
|
| 29 |
-
2: "high"
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
-
model
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
model.eval()
|
| 42 |
|
|
|
|
|
|
|
| 43 |
app = FastAPI(title="RAG + Conversation Memory API")
|
| 44 |
|
| 45 |
# ---------------------
|
| 46 |
# CLASSIFIER
|
| 47 |
# ---------------------
|
| 48 |
def classify_text(text: str, threshold: float = 0.5):
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
outputs = model(**inputs)
|
| 53 |
|
|
|
|
| 54 |
category_probs = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()
|
| 55 |
|
| 56 |
categories = [
|
|
@@ -62,6 +115,7 @@ def classify_text(text: str, threshold: float = 0.5):
|
|
| 62 |
if category_probs[i] >= threshold
|
| 63 |
]
|
| 64 |
|
|
|
|
| 65 |
urgency_probs = torch.softmax(outputs.urgency_logits, dim=-1)[0].cpu().numpy()
|
| 66 |
urgency_idx = int(torch.argmax(outputs.urgency_logits, dim=-1)[0])
|
| 67 |
|
|
@@ -74,6 +128,9 @@ def classify_text(text: str, threshold: float = 0.5):
|
|
| 74 |
|
| 75 |
|
| 76 |
def retrieve_documents(query: str):
|
|
|
|
|
|
|
|
|
|
| 77 |
return [
|
| 78 |
"Restarting the router fixes most connectivity issues.",
|
| 79 |
"Check for planned ISP maintenance.",
|
|
@@ -81,19 +138,32 @@ def retrieve_documents(query: str):
|
|
| 81 |
]
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
@app.post("/query", response_model=QueryResponse)
|
| 85 |
def query_endpoint(req: QueryRequest):
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
| 87 |
history = get_conversation(req.user_id)
|
| 88 |
|
| 89 |
-
#
|
| 90 |
categories, urgency = classify_text(req.query)
|
| 91 |
|
| 92 |
-
#
|
| 93 |
docs = retrieve_documents(req.query)
|
| 94 |
answer = generate_answer(req.query, docs, history)
|
| 95 |
|
| 96 |
-
#
|
| 97 |
add_message(req.user_id, "user", req.query)
|
| 98 |
add_message(req.user_id, "assistant", answer)
|
| 99 |
|
|
@@ -105,3 +175,22 @@ def query_endpoint(req: QueryRequest):
|
|
| 105 |
urgency=urgency,
|
| 106 |
conversation=get_conversation(req.user_id)
|
| 107 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import os
|
| 3 |
from fastapi import FastAPI
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
|
| 7 |
from models import (
|
| 8 |
QueryRequest,
|
|
|
|
| 10 |
CategoryPrediction,
|
| 11 |
UrgencyPrediction
|
| 12 |
)
|
| 13 |
+
from multi_task_model_class import MultiTaskModel
|
| 14 |
from rag import generate_answer
|
| 15 |
from memory import get_conversation, add_message
|
| 16 |
|
| 17 |
# Inference device: prefer GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face repo holding the fine-tuned classifier weights and tokenizer.
CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
# Base encoder architecture the multi-task heads were trained on.
ENCODER_NAME = "distilbert-base-uncased"

# NOTE(review): this label list contains duplicates ('Classroom/Lab Support'
# appears four times). Its length (10) sizes the category head, and its order
# must match the training-time label order — confirm against the training run
# before deduplicating, since changing length/order would break the checkpoint.
tag_classes = ['Email & Communication', 'Classroom/Lab Support', 'Software & Applications', 'Classroom/Lab Support', 'Classroom/Lab Support', 'Network & Connectivity', 'General IT Support', 'Data Management', 'Classroom/Lab Support', 'Security & Compliance']

# Maps the urgency head's argmax index to a human-readable level.
urgency_encoder = {
    0: "low",
    1: "medium",
    2: "high",
    3: "critical"  # Added 4th level to match num_urgency_labels=4 below
}

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID, trust_remote_code=True)

print("Initializing model structure...")
# Head sizes must match the checkpoint loaded below.
model = MultiTaskModel(
    encoder_name=ENCODER_NAME,
    num_category_labels=len(tag_classes),
    num_urgency_labels=4
)

# Load model weights
print("Downloading model weights...")
try:
    # Method 1: fetch the checkpoint file from the Hub (cached after first run).
    model_path = hf_hub_download(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
        token=None,  # Set to your HF token if repo is private
    )
    print(f"✓ Model downloaded to: {model_path}")

    print("Loading model weights...")
    # NOTE(review): weights_only=False unpickles arbitrary objects and is only
    # safe for trusted checkpoints; a plain state_dict should load with
    # weights_only=True — confirm checkpoint contents before changing.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(state_dict)
    print("✓ Model weights loaded successfully")

except Exception as e:
    print(f"✗ Error downloading from Hugging Face: {e}")
    print("\nTrying alternative methods...")

    # Method 2: Try loading from cache
    from huggingface_hub import try_to_load_from_cache
    cache_path = try_to_load_from_cache(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin"
    )

    if cache_path and os.path.exists(cache_path):
        print(f"✓ Found in cache: {cache_path}")
        state_dict = torch.load(cache_path, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✓ Model loaded from cache")
    else:
        # No usable weights anywhere: print actionable guidance, then re-raise
        # the original download error so startup fails loudly.
        print("\n" + "="*60)
        print("ERROR: Could not load model weights")
        print("="*60)
        print("\nPossible solutions:")
        print("1. Login to Hugging Face:")
        print(" huggingface-cli login")
        print("\n2. Or download manually:")
        print(f" Visit: https://huggingface.co/{CLASSIFIER_MODEL_ID}/tree/main")
        print(f" Download 'pytorch_model.bin' to: ./Sandei/tech-support-classifier/")
        print("\n3. Check your internet connection")
        print("="*60)
        raise

# Move to the chosen device and freeze in inference mode (disables dropout etc.).
model.to(DEVICE)
model.eval()

print(f"\n✓ Model ready on {DEVICE}\n")
|
| 91 |
+
|
| 92 |
# FastAPI application instance; route handlers are registered on it below.
app = FastAPI(title="RAG + Conversation Memory API")

# ---------------------
# CLASSIFIER
# ---------------------
|
| 97 |
def classify_text(text: str, threshold: float = 0.5):
|
| 98 |
+
"""
|
| 99 |
+
Classify input text into categories and urgency level.
|
| 100 |
+
"""
|
| 101 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
|
| 102 |
|
| 103 |
with torch.no_grad():
|
| 104 |
outputs = model(**inputs)
|
| 105 |
|
| 106 |
+
# Category predictions (multi-label)
|
| 107 |
category_probs = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()
|
| 108 |
|
| 109 |
categories = [
|
|
|
|
| 115 |
if category_probs[i] >= threshold
|
| 116 |
]
|
| 117 |
|
| 118 |
+
# Urgency prediction (multi-class)
|
| 119 |
urgency_probs = torch.softmax(outputs.urgency_logits, dim=-1)[0].cpu().numpy()
|
| 120 |
urgency_idx = int(torch.argmax(outputs.urgency_logits, dim=-1)[0])
|
| 121 |
|
|
|
|
| 128 |
|
| 129 |
|
| 130 |
def retrieve_documents(query: str):
|
| 131 |
+
"""
|
| 132 |
+
Retrieve relevant documents for RAG.
|
| 133 |
+
"""
|
| 134 |
return [
|
| 135 |
"Restarting the router fixes most connectivity issues.",
|
| 136 |
"Check for planned ISP maintenance.",
|
|
|
|
| 138 |
]
|
| 139 |
|
| 140 |
|
| 141 |
+
@app.get("/")
def root():
    """Health-check endpoint: reports service status, device, and model id."""
    status_report = {"status": "running"}
    status_report["device"] = DEVICE
    status_report["model"] = CLASSIFIER_MODEL_ID
    return status_report
|
| 149 |
+
|
| 150 |
+
|
| 151 |
@app.post("/query", response_model=QueryResponse)
|
| 152 |
def query_endpoint(req: QueryRequest):
|
| 153 |
+
"""
|
| 154 |
+
Main query endpoint.
|
| 155 |
+
"""
|
| 156 |
+
# Load conversation history
|
| 157 |
history = get_conversation(req.user_id)
|
| 158 |
|
| 159 |
+
# Classification
|
| 160 |
categories, urgency = classify_text(req.query)
|
| 161 |
|
| 162 |
+
# RAG
|
| 163 |
docs = retrieve_documents(req.query)
|
| 164 |
answer = generate_answer(req.query, docs, history)
|
| 165 |
|
| 166 |
+
# Update conversation memory
|
| 167 |
add_message(req.user_id, "user", req.query)
|
| 168 |
add_message(req.user_id, "assistant", answer)
|
| 169 |
|
|
|
|
| 175 |
urgency=urgency,
|
| 176 |
conversation=get_conversation(req.user_id)
|
| 177 |
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@app.post("/classify")
def classify_endpoint(req: QueryRequest):
    """
    Standalone classification endpoint.

    Runs the category/urgency classifier on the query text only — no RAG
    generation and no conversation-memory update.
    """
    predicted_categories, predicted_urgency = classify_text(req.query)
    response_payload = {"query": req.query}
    response_payload["categories"] = predicted_categories
    response_payload["urgency"] = predicted_urgency
    return response_payload
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Allow running directly via `python app.py`; binds on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
multi_task_model_class.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple

import torch.nn as nn
from transformers import AutoModel


# Immutable result container exposing `.category_logits` and `.urgency_logits`,
# the attributes callers (app.classify_text) read. Defined once at module level
# instead of the previous `type("Output", (), {...})()`, which rebuilt a
# throwaway class object on every forward pass.
MultiTaskOutput = namedtuple("MultiTaskOutput", ["category_logits", "urgency_logits"])


class MultiTaskModel(nn.Module):
    """Shared text encoder with two linear task heads.

    - category head: multi-label (caller applies sigmoid per class)
    - urgency head:  multi-class (caller applies softmax)
    """

    def __init__(self, encoder_name, num_category_labels, num_urgency_labels):
        """
        Args:
            encoder_name: HF model id of the pretrained encoder
                (e.g. "distilbert-base-uncased").
            num_category_labels: output size of the multi-label category head.
            num_urgency_labels: output size of the urgency head (e.g. 4).
        """
        super().__init__()

        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size

        # Attribute names must match the checkpoint's state_dict keys
        # (category_classifier.* / urgency_classifier.*), renamed from the
        # earlier category_head / urgency_head.
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return raw (pre-activation) logits for both heads."""
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # First-token ([CLS]-position) hidden state as the pooled representation.
        pooled = outputs.last_hidden_state[:, 0]

        return MultiTaskOutput(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )
|