# backend/agents/classification_agent.py
"""Email classification agent and the FastAPI-facing wrappers around it."""
import os
import time

import torch

# --- REMOVED HARDCODED STORAGE CONFIG ---
# Hugging Face Spaces & Docker will manage the HF_HOME cache path automatically.
# DO NOT set 'D:/' paths here anymore.


class ClassificationAgent:
    """Classify an email into a category using a local text-classification model.

    The model and tokenizer are loaded from the ``classification_model``
    directory that sits next to this file.
    """

    def __init__(self):
        """Resolve the model path, pick a device and build the pipeline."""
        # Imported here rather than at module level so that importing this
        # module (e.g. to use the wrappers or `process` on an existing agent)
        # does not pull in the transformers stack until a model is loaded.
        from transformers import pipeline

        # --- DYNAMIC PATH RESOLUTION ---
        # Resolved relative to this file, so it works both on a local Windows
        # checkout and inside the Linux Docker container.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(base_dir, 'classification_model')

        # 0 is passed to the pipeline when CUDA is available, -1 otherwise.
        self.device = 0 if torch.cuda.is_available() else -1
        print(f"--- Loading Specialist Agent on {'GPU' if self.device == 0 else 'CPU'} ---")

        self.classifier = pipeline(
            "text-classification",
            model=self.model_path,
            tokenizer=self.model_path,
            device=self.device,
            # Ask the pipeline to truncate over-long emails to 512 tokens.
            truncation=True,
            max_length=512,
        )

        # Maps the pipeline's raw label ids to human-readable categories.
        self.id2label = {
            "LABEL_0": "Work",
            "LABEL_1": "Personal",
            "LABEL_2": "Finance",
            "LABEL_3": "Travel",
            "LABEL_4": "Social",
        }

    def process(self, subject: str, body: str) -> dict:
        """Classify a single email and report timing metadata.

        Args:
            subject: The email subject line. ``None`` is treated as empty.
            body: The email body text. ``None`` is treated as empty.

        Returns:
            A dict with:
              * ``category``: the mapped label, or ``"Uncategorized"`` when the
                pipeline returns a label id missing from ``self.id2label``.
              * ``confidence``: the pipeline score rounded to 4 decimals.
              * ``metrics``: ``latency_s`` (inference wall time in seconds,
                rounded to 4 decimals), ``model`` and ``processor``.
        """
        # Coerce missing fields to "" so they are not rendered as the literal
        # text "None" inside the model input.
        subject = subject or ""
        body = body or ""
        text_input = f"Subject: {subject} | Content: {body}"

        start_time = time.perf_counter()
        # The pipeline returns a list of results; a single input yields one.
        result = self.classifier(text_input)[0]
        end_time = time.perf_counter()

        predicted_label = self.id2label.get(result['label'], "Uncategorized")

        return {
            "category": predicted_label,
            "confidence": round(result['score'], 4),
            "metrics": {
                "latency_s": round(end_time - start_time, 4),
                "model": "classification_model",
                "processor": "gpu" if self.device == 0 else "cpu",
            },
        }


# --- FASTAPI WRAPPERS ---

# Module-wide agent; stays None until load_classification_model() has run.
classification_agent_instance = None


def load_classification_model() -> None:
    """Called ONCE by FastAPI when the server starts."""
    global classification_agent_instance
    classification_agent_instance = ClassificationAgent()
    print("🗂️ Classification Agent ready.")


async def run_classification(subject: str, text: str) -> dict:
    """Called on every single email request.

    Args:
        subject: The email subject line.
        text: The email body text.

    Returns:
        The result dict produced by ``ClassificationAgent.process``.

    Raises:
        RuntimeError: If ``load_classification_model`` has not been called yet.
    """
    # Read-only access to the module-level agent, so no `global` is needed.
    agent = classification_agent_instance
    if agent is None:
        raise RuntimeError("Classification model is not loaded into memory.")

    # Run the pipeline inference.
    # NOTE(review): `process` is a synchronous call, so the event loop is
    # blocked while inference runs — confirm this is acceptable under load.
    return agent.process(subject, text)