Spaces:
Sleeping
Sleeping
File size: 2,691 Bytes
# backend/agents/classification_agent.py
import os
import time
import torch
from transformers import pipeline
# --- REMOVED HARDCODED STORAGE CONFIG ---
# Hugging Face Spaces & Docker will manage the HF_HOME cache path automatically.
# DO NOT set 'D:/' paths here anymore.
# --- CLASSIFICATION AGENT ---
class ClassificationAgent:
    """Email classifier backed by a locally stored text-classification model.

    The model is loaded from the ``classification_model`` directory that sits
    next to this file. The pipeline's raw ``LABEL_n`` outputs are translated to
    category names through ``self.id2label``.
    """

    def __init__(self):
        """Load the text-classification pipeline and the label mapping.

        The pipeline is placed on CUDA device 0 when
        ``torch.cuda.is_available()`` is true, otherwise on the CPU (``-1``).
        """
        # Resolve the model directory relative to this file rather than the
        # current working directory, so no absolute path is hard-coded.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(base_dir, 'classification_model')

        self.device = 0 if torch.cuda.is_available() else -1
        print(f"--- Loading Specialist Agent on {'GPU' if self.device == 0 else 'CPU'} ---")

        # Inputs longer than 512 tokens are truncated instead of raising.
        self.classifier = pipeline(
            "text-classification",
            model=self.model_path,
            tokenizer=self.model_path,
            device=self.device,
            truncation=True,
            max_length=512,
        )

        # Maps the pipeline's label strings to the category names returned
        # by process().
        self.id2label = {
            "LABEL_0": "Work",
            "LABEL_1": "Personal",
            "LABEL_2": "Finance",
            "LABEL_3": "Travel",
            "LABEL_4": "Social",
        }

    def process(self, subject: str, body: str):
        """Classify one email and return the prediction with timing metrics.

        Args:
            subject: Email subject line. ``None`` is treated as an empty string.
            body: Email body text. ``None`` is treated as an empty string.

        Returns:
            A dict with:
              * ``category``: the mapped category name, or ``"Uncategorized"``
                when the pipeline's label is not present in ``self.id2label``.
              * ``confidence``: the pipeline score rounded to 4 decimals.
              * ``metrics``: ``latency_s`` (seconds spent in the pipeline call,
                rounded to 4 decimals), ``model`` and ``processor``
                (``"gpu"`` when ``self.device == 0``, else ``"cpu"``).
        """
        # A missing field would otherwise be rendered as the literal text
        # "None" and fed to the model as if it were real email content.
        subject = "" if subject is None else subject
        body = "" if body is None else body
        text_input = f"Subject: {subject} | Content: {body}"

        # Time only the pipeline call itself.
        start_time = time.perf_counter()
        result = self.classifier(text_input)[0]
        end_time = time.perf_counter()

        predicted_label = self.id2label.get(result['label'], "Uncategorized")
        return {
            "category": predicted_label,
            "confidence": round(result['score'], 4),
            "metrics": {
                "latency_s": round(end_time - start_time, 4),
                "model": "classification_model",
                "processor": "gpu" if self.device == 0 else "cpu",
            },
        }
# --- FASTAPI WRAPPERS ---
# Module-level holder for the shared agent. It stays None until
# load_classification_model() assigns a ClassificationAgent to it;
# run_classification() raises RuntimeError while it is still None.
classification_agent_instance = None
def load_classification_model():
    """Construct the ClassificationAgent and store it in the module global.

    Intended to be called once when the FastAPI server starts. Each call
    builds a new agent and rebinds ``classification_agent_instance`` to it.
    """
    global classification_agent_instance
    agent = ClassificationAgent()
    classification_agent_instance = agent
    print("🗂️ Classification Agent ready.")
async def run_classification(subject: str, text: str) -> dict:
    """Classify a single email using the already-loaded agent.

    Args:
        subject: Email subject line.
        text: Email body, forwarded as the ``body`` argument of ``process``.

    Returns:
        The dict produced by ``ClassificationAgent.process``.

    Raises:
        RuntimeError: If no agent has been stored in
            ``classification_agent_instance`` yet.
    """
    agent = classification_agent_instance
    if agent is None:
        raise RuntimeError("Classification model is not loaded into memory.")

    # NOTE: nothing is awaited here; process() runs synchronously inside
    # this coroutine.
    return agent.process(subject, text)