File size: 5,197 Bytes
81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 a5d886c 81b1a96 a5d886c 81b1a96 3243c38 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 81b1a96 9c60f47 c39871c 81b1a96 9c60f47 81b1a96 9c60f47 d74109c 39d9710 d74109c 9c60f47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | import torch
import os
from fastapi import FastAPI
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from models import (
QueryRequest,
QueryResponse,
CategoryPrediction,
UrgencyPrediction
)
from multi_task_model_class import MultiTaskModel
from memory import get_conversation, add_message
from service.rag_service import generate_answer
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Hub repository holding the tokenizer and the fine-tuned weights.
CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
# Encoder name handed to MultiTaskModel when the architecture is built.
ENCODER_NAME = "distilbert-base-uncased"

# Category label for each position of the category head's output.
# The list is positional, so repeated names must stay where they are.
# NOTE(review): 'Classroom/Lab Support' occupies four positions; presumably
# several raw training labels were merged into it -- confirm against the
# label order used at training time.
tag_classes = [
    'Email & Communication',
    'Classroom/Lab Support',
    'Software & Applications',
    'Classroom/Lab Support',
    'Classroom/Lab Support',
    'Network & Connectivity',
    'General IT Support',
    'Data Management',
    'Classroom/Lab Support',
    'Security & Compliance',
]

# Urgency class index -> human-readable label (four levels, 0 through 3).
urgency_encoder = dict(enumerate(("low", "medium", "high", "critical")))
# ---------------------
# MODEL BOOTSTRAP
# ---------------------
# Everything below runs at import time: it loads the tokenizer, builds the
# model skeleton, fetches the fine-tuned weights and creates the FastAPI app.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID, trust_remote_code=True)

print("Initializing model structure...")
model = MultiTaskModel(
    encoder_name=ENCODER_NAME,
    num_category_labels=len(tag_classes),
    # Derived from the label map (4 entries) instead of a literal so the
    # urgency head cannot silently drift from `urgency_encoder`.
    num_urgency_labels=len(urgency_encoder),
)

# Load model weights
print("Downloading model weights...")
try:
    model_path = hf_hub_download(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
        token=None,  # Set to your HF token if repo is private
    )
    print(f"β Model downloaded to: {model_path}")
    print("Loading model weights...")
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded file. Kept as-is for compatibility with the published
    # checkpoint; switch to weights_only=True once it is confirmed to be a
    # plain tensor state dict.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(state_dict)
    print("β Model weights loaded successfully")
except Exception as e:
    print(f"β Error downloading from Hugging Face: {e}")
    print("\nTrying alternative methods...")

    # Method 2: Try loading from cache
    from huggingface_hub import try_to_load_from_cache

    cache_path = try_to_load_from_cache(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
    )
    # try_to_load_from_cache returns a path string on a hit, None when the
    # file is unknown, or a non-string sentinel when the cache recorded that
    # the file does not exist. The sentinel is truthy and would make
    # os.path.exists raise TypeError, hiding the original error, so only a
    # real string path is accepted here.
    if isinstance(cache_path, str) and os.path.exists(cache_path):
        print(f"β Found in cache: {cache_path}")
        # NOTE(security): same weights_only=False caveat as above.
        state_dict = torch.load(cache_path, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("β Model loaded from cache")
    else:
        print("\n" + "="*60)
        print("ERROR: Could not load model weights")
        print("="*60)
        print("\nPossible solutions:")
        print("1. Login to Hugging Face:")
        print(" huggingface-cli login")
        print("\n2. Or download manually:")
        print(f" Visit: https://huggingface.co/{CLASSIFIER_MODEL_ID}/tree/main")
        print(f" Download 'pytorch_model.bin' to: ./Sandei/tech-support-classifier/")
        print("\n3. Check your internet connection")
        print("="*60)
        # Re-raise the original failure so the process does not start
        # serving with randomly initialized weights.
        raise

model.to(DEVICE)
model.eval()  # inference mode
print(f"\nβ Model ready on {DEVICE}\n")

app = FastAPI(title="RAG + Conversation Memory API")
# ---------------------
# CLASSIFIER
# ---------------------
def classify_text(text: str, threshold: float = 0.5):
    """
    Run the multi-task classifier on ``text``.

    Returns a ``(categories, urgency)`` tuple. ``categories`` holds one
    ``CategoryPrediction`` for every tag whose sigmoid score is at least
    ``threshold`` (it may be empty). ``urgency`` is the ``UrgencyPrediction``
    for the highest-scoring urgency class, with its softmax probability as
    the confidence.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoded = encoded.to(DEVICE)

    with torch.no_grad():
        outputs = model(**encoded)

    # Category head is multi-label: each tag gets an independent sigmoid
    # score, so any number of tags can clear the threshold.
    tag_scores = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()
    categories = []
    for position, tag in enumerate(tag_classes):
        score = tag_scores[position]
        if score >= threshold:
            categories.append(
                CategoryPrediction(category=tag, confidence=float(score))
            )

    # Urgency head is multi-class: softmax over the classes, keep the top one.
    urgency_logits = outputs.urgency_logits
    urgency_scores = torch.softmax(urgency_logits, dim=-1)[0].cpu().numpy()
    top_index = int(torch.argmax(urgency_logits, dim=-1)[0])
    urgency = UrgencyPrediction(
        label=urgency_encoder[top_index],
        confidence=float(urgency_scores[top_index]),
    )

    return categories, urgency
@app.get("/")
def root():
    """Health check: report that the service is up, plus device and model id."""
    health = dict(
        status="running",
        device=DEVICE,
        model=CLASSIFIER_MODEL_ID,
    )
    return health
@app.post("/query", response_model=QueryResponse)
def query_endpoint(req: QueryRequest):
    """
    Main query endpoint.

    Classifies the incoming query, generates an answer for it, and returns
    both alongside the echoed ``user_id`` and ``query``.

    NOTE(review): conversation memory is neither read nor updated here, even
    though the memory helpers are imported at the top of the file.
    """
    # Step 1: category and urgency predictions for the raw query text.
    predicted_categories, predicted_urgency = classify_text(req.query)

    # Step 2: RAG answer for the same query.
    rag_answer = generate_answer(req.query)

    response = QueryResponse(
        user_id=req.user_id,
        query=req.query,
        answer=rag_answer,
        categories=predicted_categories,
        urgency=predicted_urgency,
    )
    return response
@app.post("/classify")
def classify_endpoint(req: QueryRequest):
    """
    Standalone classification endpoint.

    Returns the query together with its predicted categories and urgency,
    without generating an answer.
    """
    predicted_categories, predicted_urgency = classify_text(req.query)
    payload = {"query": req.query}
    payload["categories"] = predicted_categories
    payload["urgency"] = predicted_urgency
    return payload
@app.on_event("startup")
def warmup():
    """
    Exercise the classifier and the RAG pipeline once at application startup.

    The results are discarded; presumably this pays one-time initialization
    costs before the first real request arrives.
    """
    _ = classify_text("hello")
    _ = generate_answer("test")
# Script entry point: serve the app with uvicorn when run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn

    # Listens on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)