# Sandei's picture
# fix 2
# 39d9710
import torch
import os
from fastapi import FastAPI
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from models import (
QueryRequest,
QueryResponse,
CategoryPrediction,
UrgencyPrediction
)
from multi_task_model_class import MultiTaskModel
from memory import get_conversation, add_message
from service.rag_service import generate_answer
# Run on the GPU when torch reports one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face repo the tokenizer and model weights are fetched from, and the
# encoder name handed to MultiTaskModel.
CLASSIFIER_MODEL_ID = "Sandei/tech-support-classifier"
ENCODER_NAME = "distilbert-base-uncased"

# Category label for each classifier output index (position i names output i).
# NOTE(review): 'Classroom/Lab Support' occupies four positions — presumably
# several training tags were collapsed into one category. Confirm against the
# training pipeline before de-duplicating: the list length is what sizes the
# model's category head.
tag_classes = [
    'Email & Communication',     # 0
    'Classroom/Lab Support',     # 1
    'Software & Applications',   # 2
    'Classroom/Lab Support',     # 3
    'Classroom/Lab Support',     # 4
    'Network & Connectivity',    # 5
    'General IT Support',        # 6
    'Data Management',           # 7
    'Classroom/Lab Support',     # 8
    'Security & Compliance',     # 9
]

# Urgency class index -> human-readable label (four levels).
urgency_encoder = {
    0: "low",
    1: "medium",
    2: "high",
    3: "critical",
}
print("Loading tokenizer...")
# NOTE(review): trust_remote_code=True allows the repo to supply and execute
# its own tokenizer code; keep it only while the repo is trusted.
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_ID, trust_remote_code=True)

print("Initializing model structure...")
model = MultiTaskModel(
    encoder_name=ENCODER_NAME,
    num_category_labels=len(tag_classes),
    # Derived from the label map so the head size cannot drift from it.
    num_urgency_labels=len(urgency_encoder),
)


def _load_weights(weights_path):
    """Load the checkpoint file at *weights_path* into the module-level model.

    Any exception raised by ``torch.load`` or ``model.load_state_dict``
    propagates to the caller. The state dict is kept local so it is not
    pinned in module globals after loading.
    """
    # SECURITY NOTE(review): weights_only=False permits full unpickling of the
    # downloaded file, which can execute arbitrary code. If the checkpoint is a
    # plain state dict, switch to weights_only=True after confirming it loads.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(state_dict)


# Load model weights
print("Downloading model weights...")
try:
    model_path = hf_hub_download(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
        token=None,  # Set to your HF token if repo is private
    )
    print(f"✓ Model downloaded to: {model_path}")
    print("Loading model weights...")
    _load_weights(model_path)
    print("✓ Model weights loaded successfully")
except Exception as e:
    # Broad catch is deliberate: this is the startup boundary, and every path
    # below either recovers from the local cache or re-raises the original
    # error. Note that a failure while *loading* the downloaded file also
    # lands here, not only download failures.
    print(f"✗ Error downloading from Hugging Face: {e}")
    print("\nTrying alternative methods...")

    # Method 2: Try loading from cache
    from huggingface_hub import try_to_load_from_cache

    cache_path = try_to_load_from_cache(
        repo_id=CLASSIFIER_MODEL_ID,
        filename="pytorch_model.bin",
    )
    # A cache hit is returned as a path string. Misses come back as None or as
    # a non-string sentinel object; the sentinel is truthy and would make
    # os.path.exists raise TypeError, hiding the real error, so require a str.
    if isinstance(cache_path, str) and os.path.exists(cache_path):
        print(f"✓ Found in cache: {cache_path}")
        _load_weights(cache_path)
        print("✓ Model loaded from cache")
    else:
        print("\n" + "="*60)
        print("ERROR: Could not load model weights")
        print("="*60)
        print("\nPossible solutions:")
        print("1. Login to Hugging Face:")
        print(" huggingface-cli login")
        print("\n2. Or download manually:")
        print(f" Visit: https://huggingface.co/{CLASSIFIER_MODEL_ID}/tree/main")
        print(" Download 'pytorch_model.bin' to: ./Sandei/tech-support-classifier/")
        print("\n3. Check your internet connection")
        print("="*60)
        # Re-raise the original failure so startup aborts with its traceback.
        raise

# Move the model to the selected device and switch it to inference mode.
model.to(DEVICE)
model.eval()
print(f"\n✓ Model ready on {DEVICE}\n")

app = FastAPI(title="RAG + Conversation Memory API")
# ---------------------
# CLASSIFIER
# ---------------------
def classify_text(text: str, threshold: float = 0.5):
    """
    Classify input text into categories and urgency level.

    Args:
        text: The text to classify.
        threshold: Minimum sigmoid probability a category output must reach
            to be reported.

    Returns:
        A ``(categories, urgency)`` tuple. ``categories`` is a list of
        CategoryPrediction with one entry per distinct label that reached the
        threshold; when several model outputs share a label (``tag_classes``
        contains repeats) the highest confidence is kept. ``urgency`` is the
        UrgencyPrediction for the highest-scoring urgency class.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    # Category predictions (multi-label): each output is scored independently.
    category_probs = torch.sigmoid(outputs.category_logits)[0].cpu().numpy()

    # Collapse outputs that map to the same label so a category is never
    # reported twice; dict insertion order keeps first-seen label order.
    best_by_label = {}
    for i, label in enumerate(tag_classes):
        prob = category_probs[i]
        if prob < threshold:
            continue
        score = float(prob)
        if label not in best_by_label or score > best_by_label[label]:
            best_by_label[label] = score

    categories = [
        CategoryPrediction(category=label, confidence=score)
        for label, score in best_by_label.items()
    ]

    # Urgency prediction (multi-class): pick the single most likely class and
    # report its softmax probability as the confidence.
    urgency_probs = torch.softmax(outputs.urgency_logits, dim=-1)[0].cpu().numpy()
    urgency_idx = int(torch.argmax(outputs.urgency_logits, dim=-1)[0])
    urgency = UrgencyPrediction(
        label=urgency_encoder[urgency_idx],
        confidence=float(urgency_probs[urgency_idx]),
    )
    return categories, urgency
@app.get("/")
def root():
    """Health check endpoint: report service status, device and model id."""
    health = {
        "status": "running",
        "device": DEVICE,
        "model": CLASSIFIER_MODEL_ID,
    }
    return health
@app.post("/query", response_model=QueryResponse)
def query_endpoint(req: QueryRequest):
    """
    Main query endpoint.

    Classifies the query, generates an answer for it, and returns both in a
    QueryResponse together with the caller's user id.

    NOTE(review): the conversation-memory helpers imported by this file
    (get_conversation / add_message) are not called here, so each request is
    handled without history and nothing is stored.
    """
    question = req.query

    # Classification
    categories, urgency = classify_text(question)

    # RAG
    answer = generate_answer(question)

    response = QueryResponse(
        user_id=req.user_id,
        query=question,
        answer=answer,
        categories=categories,
        urgency=urgency,
    )
    return response
@app.post("/classify")
def classify_endpoint(req: QueryRequest):
    """
    Standalone classification endpoint.

    Runs only the classifier on the request's query and returns the query
    with its predicted categories and urgency; no answer is generated.
    """
    question = req.query
    categories, urgency = classify_text(question)
    payload = {
        "query": question,
        "categories": categories,
        "urgency": urgency,
    }
    return payload
# NOTE(review): FastAPI documents on_event as deprecated in favour of lifespan
# handlers; migrating would also require changing how `app` is constructed.
@app.on_event("startup")
def warmup():
    """Exercise the classifier and the answer generator once at startup.

    The results are discarded. Any exception raised by either call
    propagates out of the startup handler.
    """
    # Classifier first, then the RAG path.
    classify_text("hello")
    generate_answer("test")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn. Host 0.0.0.0 listens on
    # all network interfaces, port 8000.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )