twissamodi's picture
fix issues
26c1f9a
import os
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
import torch
# Set HF endpoint explicitly to avoid DNS issues.
# The value is the public Hugging Face Hub URL. Because this is a plain
# assignment, any HF_ENDPOINT already present in the environment (e.g. a
# mirror) is overwritten.
# NOTE(review): huggingface_hub appears to read HF_ENDPOINT once, when it is
# first imported, and `transformers` is imported above this line, so this
# assignment may not affect downloads made through it. Confirm; setting it
# before the third-party imports would guarantee it takes effect.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
# Hub id of the base checkpoint. IntentClassifier loads the tokenizer and both
# base-model copies (fine-tuned and zero-shot) from it.
MODEL_BASE = "Qwen/Qwen2.5-0.5B"
# Hub id of the PEFT adapter applied on top of MODEL_BASE for fine-tuned predictions.
PEFT_MODEL = "twissamodi/qwen2.5-banking77-intent-classifier"
# Class index -> intent label. The classifiers index this list with the
# indices returned by torch.topk over the model's logits, and
# len(LABEL_NAMES) is passed as num_labels. The order therefore presumably has
# to match the label order the adapter was trained with: do not sort or
# rename entries.
# The first 77 entries appear to be the Banking77 label names verbatim
# (including the capitalised "Refund_not_showing_up" and the trailing "?" in
# "reverted_card_payment?"). The final "unknown" is an extra class.
LABEL_NAMES = [
    "activate_my_card", "age_limit", "apple_pay_or_google_pay", "atm_support",
    "automatic_top_up", "balance_not_updated_after_bank_transfer",
    "balance_not_updated_after_cheque_or_cash_deposit", "beneficiary_not_allowed",
    "cancel_transfer", "card_about_to_expire", "card_acceptance", "card_arrival",
    "card_delivery_estimate", "card_linking", "card_not_working",
    "card_payment_fee_charged", "card_payment_not_recognised",
    "card_payment_wrong_exchange_rate", "card_swallowed", "cash_withdrawal_charge",
    "cash_withdrawal_not_recognised", "change_pin", "compromised_card",
    "contactless_not_working", "country_support", "declined_card_payment",
    "declined_cash_withdrawal", "declined_transfer",
    "direct_debit_payment_not_recognised", "disposable_card_limits",
    "edit_personal_details", "exchange_charge", "exchange_rate", "exchange_via_app",
    "extra_charge_on_statement", "failed_transfer", "fiat_currency_support",
    "get_disposable_virtual_card", "get_physical_card", "getting_spare_card",
    "getting_virtual_card", "lost_or_stolen_card", "lost_or_stolen_phone",
    "order_physical_card", "passcode_forgotten", "pending_card_payment",
    "pending_cash_withdrawal", "pending_top_up", "pending_transfer", "pin_blocked",
    "receiving_money", "Refund_not_showing_up", "request_refund",
    "reverted_card_payment?", "supported_cards_and_currencies", "terminate_account",
    "top_up_by_bank_transfer_charge", "top_up_by_card_charge",
    "top_up_by_cash_or_cheque", "top_up_failed", "top_up_limits", "top_up_reverted",
    "topping_up_by_card", "transaction_charged_twice", "transfer_fee_charged",
    "transfer_into_account", "transfer_not_received_by_recipient", "transfer_timing",
    "unable_to_verify_identity", "verify_my_identity", "verify_source_of_funds",
    "verify_top_up", "virtual_card_not_working", "visa_or_mastercard",
    "why_verify_identity", "wrong_amount_of_cash_received",
    "wrong_exchange_rate_for_cash_withdrawal",
    "unknown"
]
# Minimum top-1 confidence, in percent (confidences are probability * 100),
# below which IntentClassifier.classify reports "unknown".
THRESHOLD = 40.0
class IntentClassifier:
    """Banking77 intent classifier: Qwen2.5 base model plus a fine-tuned PEFT adapter.

    Loading sets up:

    * ``self.tokenizer``: tokenizer of ``MODEL_BASE`` (pad token = EOS token).
    * ``self.model``: ``MODEL_BASE`` wrapped with the ``PEFT_MODEL`` adapter.
      This is what :meth:`classify` runs.
    * ``self.base_model_for_peft``: the base copy the adapter was applied to.
    * ``self.base_model``: a second, independent base copy that PEFT never
      touches, kept for zero-shot comparison.

    All models are loaded onto ``self.device`` (CUDA when available, else CPU).
    :meth:`classify` sends its inputs to the same device.

    NOTE(review): the base checkpoint presumably ships no sequence-classification
    head, so ``self.base_model``'s ``len(LABEL_NAMES)``-way head is likely freshly
    initialised rather than trained. Confirm before treating its predictions as a
    meaningful baseline.
    """

    def __init__(self, max_retries: int = 3, retry_delay: float = 5.0, torch_dtype=None):
        """Load the tokenizer and models, retrying the whole load on failure.

        Args:
            max_retries: total number of loading attempts (clamped to at least 1).
            retry_delay: seconds to wait between attempts.
            torch_dtype: dtype for both base-model copies. ``None`` selects
                float16 on CUDA and float32 on CPU.

        Raises:
            Exception: whatever the final loading attempt raised (re-raised as is).
        """
        print("Loading classifier...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if torch_dtype is None:
            # Half precision is the fast path on GPU. On CPU, fp16 kernels are
            # slow (and missing in some PyTorch builds), so use float32 there.
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.dtype = torch_dtype

        max_retries = max(1, int(max_retries))
        # Retries exist for transient network failures while fetching from the Hub.
        for attempt in range(1, max_retries + 1):
            try:
                self._load_components()
            except Exception as e:
                if attempt == max_retries:
                    print(f"Failed to load models after {max_retries} attempts: {e}")
                    raise
                print(
                    f"Attempt {attempt}/{max_retries} failed: {e}. "
                    f"Retrying in {retry_delay:g}s..."
                )
                time.sleep(retry_delay)
            else:
                print("Classifier loaded!")
                break

    def _load_components(self):
        """Run one loading attempt: tokenizer, two base-model copies, then the PEFT wrapper."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_BASE,
            local_files_only=False,
            trust_remote_code=True,
        )
        # classify() tokenizes with padding=True; reuse EOS as the pad token
        # (presumably because the base tokenizer may not define one).
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Two separate base copies so that PEFT's in-place layer injection into
        # one cannot contaminate the one kept for zero-shot comparison.
        self.base_model_for_peft = self._load_base_model()
        self.base_model = self._load_base_model()

        self.model = PeftModel.from_pretrained(
            self.base_model_for_peft,
            PEFT_MODEL,
            local_files_only=False,
        )
        self.model.eval()

    def _load_base_model(self):
        """Load a fresh eval-mode copy of MODEL_BASE with a len(LABEL_NAMES)-way head."""
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_BASE,
            num_labels=len(LABEL_NAMES),
            torch_dtype=self.dtype,
            # Place weights on the same device classify() sends inputs to.
            # A hard-coded "cpu" here made every call fail on CUDA hosts.
            device_map=self.device,
        )
        model.eval()
        return model

    def classify(self, text: str) -> dict:
        """Classify one utterance with the fine-tuned model.

        Args:
            text: the user utterance (truncated to 128 tokens).

        Returns:
            A dict with these keys:

            * ``top3``: the three most probable labels, in descending order, as
              ``{"intent", "confidence"}`` dicts. Confidence is a percentage
              rounded to 2 decimals.
            * ``confidence``: the top-1 confidence.
            * ``top_intent``: the top-1 label, replaced by ``"unknown"`` when
              that label is ``"unknown"`` or its confidence is below
              ``THRESHOLD``.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax in float32, so fp16 quantisation does not leak into the
        # 2-decimal percentages.
        probs = torch.softmax(logits.float(), dim=-1)

        top = torch.topk(probs[0], k=3)
        results = [
            {"intent": LABEL_NAMES[index], "confidence": round(score * 100, 2)}
            for score, index in zip(top.values.tolist(), top.indices.tolist())
        ]

        best = results[0]
        is_unknown = best["intent"] == "unknown" or best["confidence"] < THRESHOLD
        return {
            "top_intent": "unknown" if is_unknown else best["intent"],
            "confidence": best["confidence"],
            "top3": results,
        }
class ZeroShotClassifier:
    """Baseline classifier: the base Qwen model without the PEFT fine-tuning.

    Used for comparison with the fine-tuned classifier in the /compare endpoint.
    It loads nothing itself. It reuses the tokenizer and model handed in by the
    caller (e.g. the tokenizer of IntentClassifier) to save memory.

    NOTE(review): if ``model`` is a base checkpoint loaded through
    AutoModelForSequenceClassification, its classification head is presumably
    freshly initialised rather than trained. Confirm before reading its
    predictions as a meaningful zero-shot baseline.
    """

    def __init__(self, tokenizer, model):
        """Store the shared tokenizer and model.

        Args:
            tokenizer: callable Hugging Face tokenizer used to encode input text.
            model: classification model whose output exposes ``.logits``,
                assumed to have one column per entry of LABEL_NAMES.
        """
        print("Zero-shot classifier ready (base model without fine-tuning).")
        # Preferred device, kept for backward compatibility. Inputs are actually
        # sent to wherever the model's weights live (see _input_device), since
        # the caller may have loaded the model elsewhere (e.g. on CPU while a
        # GPU is available).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = tokenizer
        self.model = model

    def _input_device(self):
        """Return the device of the model's first parameter.

        Falls back to ``self.device`` when the model exposes no parameters.
        """
        parameters = getattr(self.model, "parameters", None)
        if parameters is None:
            return self.device
        first = next(iter(parameters()), None)
        return self.device if first is None else first.device

    def classify(self, text: str) -> dict:
        """Classify one utterance with the un-fine-tuned model.

        Args:
            text: the user utterance (truncated to 128 tokens).

        Returns:
            A dict with these keys:

            * ``top3``: the three most probable labels, in descending order, as
              ``{"intent", "confidence"}`` dicts. Confidence is a percentage
              rounded to 2 decimals.
            * ``top_intent`` and ``confidence``: the top-1 entry.
            * ``fallback`` (always ``False``) and ``fallback_message`` (always
              ``None``): no confidence threshold is applied here.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        ).to(self._input_device())
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax in float32, so half-precision models don't lose resolution
        # in the reported percentages.
        probs = torch.softmax(logits.float(), dim=-1)

        top = torch.topk(probs[0], k=3)
        results = [
            {"intent": LABEL_NAMES[index], "confidence": round(score * 100, 2)}
            for score, index in zip(top.values.tolist(), top.indices.tolist())
        ]
        return {
            "top_intent": results[0]["intent"],
            "confidence": results[0]["confidence"],
            "top3": results,
            "fallback": False,
            "fallback_message": None
        }