"""LLM-assisted error analysis for a Transformer sentiment classifier.

Loads a trained checkpoint, runs the model over a split of review texts,
and for each misclassified review asks an OpenAI chat model for a short
explanation of why the classifier may have failed. The collected records
are written to a JSON file.
"""
import argparse
import json
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple
from torch.utils.data import DataLoader

# Project-local helpers; these are expected to live in c1.py.
from c1 import (
    IMDBDataset,
    TransformerClassifier,
    preprocess_data,
    evaluate_model,
    load_imdb_texts,
    MODEL_PATH,
)

# Requires the openai package: pip install openai
from openai import OpenAI

# The API key is read from a local JSON file (entry "openai"), not from
# environment variables. This runs at import time, so importing this module
# requires the file to exist.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]
client = OpenAI(api_key=openai_api_key)


def get_llm_explanation(review_text: str, true_y: int, pred_y: int) -> str:
    """Ask an LLM why the classifier may have misclassified a review.

    Args:
        review_text: The raw review; only its first 1000 characters are
            placed in the prompt.
        true_y: Ground-truth label, 0 (Negative) or 1 (Positive).
        pred_y: Label predicted by the model, 0 or 1.

    Returns:
        The stripped LLM response, or a string starting with
        "LLM Analysis failed:" when the request (or reading its result)
        raises.
    """
    sentiment = {0: "Negative", 1: "Positive"}

    prompt = f"""
    A Transformer model misclassified the following movie review.

    REVIEW: "{review_text[:1000]}"
    TRUE LABEL: {sentiment[true_y]}
    MODEL PREDICTED: {sentiment[pred_y]}

    Task: Provide a concise (2-3 sentence) explanation of why a machine learning model
    might have struggled with this specific text. Mention linguistic features like
    sarcasm, double negatives, mixed sentiment, or specific keywords.
    """

    # Deliberately best-effort: one failed API call must not abort the whole
    # analysis run, so any error is reported as the explanation text instead.
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",  # Using 4o-mini as a high-performance proxy for "mini" models
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"LLM Analysis failed: {str(e)}"


def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 10,
) -> List[Dict]:
    """Find misclassified texts and attach an LLM explanation to each.

    The model is switched to eval mode and run over ``texts`` in order, in
    batches of 64. Processing stops as soon as ``num_examples`` errors have
    been collected.

    Args:
        model: Classifier returning per-class logits for a batch of sequences.
        texts: Raw input texts.
        labels: Ground-truth labels aligned with ``texts``.
        vocab: Token-to-index mapping passed to ``preprocess_data``.
        max_len: Sequence length passed to ``preprocess_data``.
        device: Device the batches are moved to before the forward pass.
        num_examples: Maximum number of misclassifications to analyze. A
            value <= 0 returns an empty list without running the model or
            calling the LLM.

    Returns:
        One dict per analyzed error with the keys ``example_id``,
        ``true_label``, ``predicted_label``, ``confidence_neg``,
        ``confidence_pos``, ``text`` and ``explanation``.
    """
    model.eval()

    # The cap below is only checked after an error has been recorded, so
    # without this guard a non-positive limit would still analyze one error
    # (including a paid LLM call).
    if num_examples <= 0:
        return []

    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    error_results = []
    printed = 0

    with torch.no_grad():
        for batch_idx, (batch_seq, batch_lab) in enumerate(loader):
            batch_seq, batch_lab = batch_seq.to(device), batch_lab.to(device)
            logits = model(batch_seq)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # shuffle=False keeps loader order equal to `texts` order, so the
            # raw texts of this batch can be recovered by position.
            start = batch_idx * loader.batch_size
            batch_texts = texts[start:start + batch_seq.size(0)]

            # tolist() yields plain Python ints/floats, matching the int
            # annotations of get_llm_explanation and the JSON output.
            for text, true_y, pred_y, prob_vec in zip(
                batch_texts,
                batch_lab.tolist(),
                preds.tolist(),
                probs.tolist(),
            ):
                if true_y == pred_y:
                    continue

                printed += 1
                print(f"Analyzing error #{printed} with LLM...")
                explanation = get_llm_explanation(text, true_y, pred_y)

                error_entry = {
                    "example_id": printed,
                    "true_label": int(true_y),
                    "predicted_label": int(pred_y),
                    "confidence_neg": float(prob_vec[0]),
                    "confidence_pos": float(prob_vec[1]),
                    "text": text,
                    "explanation": explanation,
                }
                error_results.append(error_entry)

                # Print to console for immediate feedback
                print("=" * 80)
                print(f"True: {true_y} | Pred: {pred_y}")
                print(f"Reasoning: {explanation}")
                print("=" * 80)

                if printed >= num_examples:
                    return error_results

    return error_results


def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """Rebuild the classifier from a checkpoint file.

    The checkpoint must contain the keys ``vocab``, ``config`` and
    ``model_state_dict``; ``config`` must provide ``d_model``, ``num_heads``,
    ``num_layers``, ``d_ff`` and ``max_len``.

    Args:
        checkpoint_path: Path of the checkpoint to load.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        A ``(model, vocab, config)`` tuple with the model on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab = ckpt["vocab"]
    config = ckpt["config"]

    model = TransformerClassifier(
        vocab_size=len(vocab),
        d_model=config["d_model"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        d_ff=config["d_ff"],
        max_len=config["max_len"],
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    return model, vocab, config


def main():
    """Command-line entry point: analyze errors and save them as JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--num_examples", type=int, default=10)
    parser.add_argument("--output", type=str, default="error_analysis.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    model, vocab, config = load_trained_model_from_checkpoint(device=device)

    # 2. Load Data
    texts, labels = load_imdb_texts(split=args.split)

    # 3. Analyze
    errors = analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=args.num_examples,
    )

    # 4. Save Results
    with open(args.output, "w") as f:
        json.dump(errors, f, indent=4)

    print(f"\nAnalysis complete. Results saved to {args.output}")


if __name__ == "__main__":
    main()