# PFE_project_backend / agents / classification_agent.py
# Provenance: uploaded by Ayoubouba ("Upload 17 files", commit 0ee60d8, verified)
# backend/agents/classification_agent.py
import os
import time
import torch
from transformers import pipeline
# --- REMOVED HARDCODED STORAGE CONFIG ---
# Hugging Face Spaces & Docker will manage the HF_HOME cache path automatically.
# DO NOT set 'D:/' paths here anymore.
# --- YOUR EXACT CLASS ---
class ClassificationAgent:
    """Email classifier backed by a locally bundled HF text-classification model.

    Loads the fine-tuned model that ships next to this file and exposes a
    single `process()` call returning the category plus timing metadata.
    """

    def __init__(self):
        # --- DYNAMIC PATH RESOLUTION ---
        # Resolve the model directory relative to this file so the same code
        # works on a local Windows checkout and inside a Linux Docker container.
        here = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(here, 'classification_model')
        # -------------------------------
        # transformers convention: device 0 = first CUDA GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        print(f"--- Loading Specialist Agent on {'GPU' if self.device == 0 else 'CPU'} ---")
        self.classifier = pipeline(
            "text-classification",
            model=self.model_path,
            tokenizer=self.model_path,
            device=self.device,
            truncation=True,
            max_length=512,
        )
        # Translate the model's raw LABEL_* outputs into readable categories.
        self.id2label = {
            "LABEL_0": "Work",
            "LABEL_1": "Personal",
            "LABEL_2": "Finance",
            "LABEL_3": "Travel",
            "LABEL_4": "Social",
        }

    def process(self, subject: str, body: str):
        """Classify one email; return category, confidence and latency metrics."""
        text_input = f"Subject: {subject} | Content: {body}"
        started = time.perf_counter()
        prediction = self.classifier(text_input)[0]
        elapsed = time.perf_counter() - started
        # Unknown labels fall back to "Uncategorized" rather than raising.
        category = self.id2label.get(prediction['label'], "Uncategorized")
        return {
            "category": category,
            "confidence": round(prediction['score'], 4),
            "metrics": {
                "latency_s": round(elapsed, 4),
                "model": "classification_model",
                "processor": "gpu" if self.device == 0 else "cpu",
            },
        }
# --- FASTAPI WRAPPERS ---
# Module-level singleton holding the loaded agent; stays None until
# load_classification_model() is called at server startup.
classification_agent_instance = None
def load_classification_model():
    """Build the singleton ClassificationAgent (run once at FastAPI startup)."""
    global classification_agent_instance
    agent = ClassificationAgent()
    classification_agent_instance = agent
    print("🗂️ Classification Agent ready.")
async def run_classification(subject: str, text: str) -> dict:
    """Classify a single email via the preloaded singleton agent.

    Raises RuntimeError if load_classification_model() has not run yet.
    """
    global classification_agent_instance
    agent = classification_agent_instance
    if agent is None:
        raise RuntimeError("Classification model is not loaded into memory.")
    # Run the pipeline inference
    return agent.process(subject, text)