| import argparse |
| import json |
| import torch |
| import torch.nn.functional as F |
| from typing import Dict, List, Tuple |
| from torch.utils.data import DataLoader |
|
|
| |
| from c1 import ( |
| IMDBDataset, |
| TransformerClassifier, |
| preprocess_data, |
| evaluate_model, |
| load_imdb_texts, |
| MODEL_PATH, |
| ) |
|
|
| |
| from openai import OpenAI |
# Module-level OpenAI client used by get_llm_explanation.
# NOTE(review): the credential path is machine-specific; the JSON file is
# expected to contain an "openai" key holding the API key string.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
openai_api_key = api_keys["openai"]


client = OpenAI(api_key=openai_api_key)
| |
|
|
def get_llm_explanation(
    review_text: str,
    true_y: int,
    pred_y: int,
    *,
    llm_model: str = "gpt-4o-mini",
    max_chars: int = 1000,
    temperature: float = 0.2,
    llm_client=None,
) -> str:
    """
    Ask an LLM for a short qualitative explanation of a misclassification.

    Args:
        review_text: Raw review text; only the first ``max_chars`` characters
            are included in the prompt.
        true_y: Gold label (0 = Negative, 1 = Positive). Other values are
            rendered verbatim rather than raising.
        pred_y: Label predicted by the classifier, same encoding as ``true_y``.
        llm_model: Chat-completions model name sent with the request.
        max_chars: Truncation limit for the review text in the prompt.
        temperature: Sampling temperature for the request.
        llm_client: Client exposing ``chat.completions.create``; defaults to
            the module-level ``client`` (injectable for testing).

    Returns:
        The LLM's explanation with surrounding whitespace stripped, or a
        string starting with "LLM Analysis failed:" when the request raised or
        produced no text. API errors are reported, not raised, so a single
        failed call does not abort a batch analysis.
    """
    sentiment = {0: "Negative", 1: "Positive"}
    # Resolve label names before the request so an unexpected label value can
    # never raise KeyError outside the best-effort try-block below.
    true_name = sentiment.get(true_y, str(true_y))
    pred_name = sentiment.get(pred_y, str(pred_y))

    prompt = "\n".join(
        [
            "A Transformer model misclassified the following movie review.",
            "",
            f'REVIEW: "{review_text[:max_chars]}"',
            f"TRUE LABEL: {true_name}",
            f"MODEL PREDICTED: {pred_name}",
            "",
            "Task: Provide a concise (2-3 sentence) explanation of why a machine learning",
            "model might have struggled with this specific text. Mention linguistic",
            "features like sarcasm, double negatives, mixed sentiment, or specific keywords.",
        ]
    )

    try:
        active_client = llm_client if llm_client is not None else client
        response = active_client.chat.completions.create(
            model=llm_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        content = response.choices[0].message.content
    except Exception as e:  # deliberate best-effort boundary around the API call
        return f"LLM Analysis failed: {str(e)}"

    # The API may return a message with no text (content=None).
    text = (content or "").strip()
    if not text:
        return "LLM Analysis failed: empty response"
    return text
|
|
def analyze_misclassifications_on_texts(
    model: torch.nn.Module,
    texts: List[str],
    labels: List[int],
    vocab: Dict[str, int],
    max_len: int,
    device: torch.device,
    num_examples: int = 10,
) -> List[Dict]:
    """
    Find misclassified reviews, ask an LLM to explain each, and collect results.

    Texts are scanned in their original order (the loader does not shuffle),
    and scanning stops as soon as ``num_examples`` errors have been collected,
    so at most ``num_examples`` LLM requests are issued.

    Args:
        model: Trained classifier; switched to eval mode by this function.
        texts: Raw review texts, aligned index-for-index with ``labels``.
        labels: Gold labels (0 = Negative, 1 = Positive).
        vocab: Token-to-id mapping passed to ``preprocess_data``.
        max_len: Sequence length passed to ``preprocess_data``.
        device: Device the model lives on; each batch is moved there.
        num_examples: Maximum number of errors to analyze. Values <= 0 return
            an empty list without running the model or calling the LLM.

    Returns:
        One dict per analyzed error with keys ``example_id`` (1-based error
        counter), ``dataset_index`` (position of the review in ``texts``),
        ``true_label``, ``predicted_label``, ``confidence_neg``,
        ``confidence_pos``, ``text`` and ``explanation``.
    """
    model.eval()
    error_results: List[Dict] = []
    # Checked up front: the in-loop limit check only runs after an error has
    # already been sent to the LLM, so a non-positive limit would otherwise
    # still produce one (paid) request and one result.
    if num_examples <= 0:
        return error_results

    sequences = preprocess_data(texts, vocab, max_len)
    dataset = IMDBDataset(sequences, labels)
    # shuffle=False keeps batches in dataset order so predictions can be
    # mapped back to their raw review through a running offset.
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    offset = 0
    with torch.no_grad():
        for batch_seq, batch_lab in loader:
            batch_seq = batch_seq.to(device)
            batch_lab = batch_lab.to(device)
            probs = F.softmax(model(batch_seq), dim=1)
            preds = torch.argmax(probs, dim=1)

            n_in_batch = batch_seq.size(0)
            batch_texts = texts[offset:offset + n_in_batch]

            rows = zip(
                batch_texts,
                batch_lab.cpu().numpy(),
                preds.cpu().numpy(),
                probs.cpu().numpy(),
            )
            for i, (text, true_y, pred_y, prob_vec) in enumerate(rows):
                if true_y == pred_y:
                    continue

                true_label, pred_label = int(true_y), int(pred_y)
                example_id = len(error_results) + 1

                print(f"Analyzing error #{example_id} with LLM...")
                explanation = get_llm_explanation(text, true_label, pred_label)

                error_results.append(
                    {
                        "example_id": example_id,
                        "dataset_index": offset + i,
                        "true_label": true_label,
                        "predicted_label": pred_label,
                        "confidence_neg": float(prob_vec[0]),
                        "confidence_pos": float(prob_vec[1]),
                        "text": text,
                        "explanation": explanation,
                    }
                )

                print("=" * 80)
                print(f"True: {true_label} | Pred: {pred_label}")
                print(f"Reasoning: {explanation}")
                print("=" * 80)

                if len(error_results) >= num_examples:
                    return error_results

            offset += n_in_batch

    return error_results
|
|
def load_trained_model_from_checkpoint(
    checkpoint_path: str = MODEL_PATH,
    device: torch.device | None = None,
) -> Tuple[torch.nn.Module, Dict[str, int], Dict]:
    """
    Rebuild a TransformerClassifier from a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must contain the keys
    ``"vocab"``, ``"config"`` and ``"model_state_dict"``; ``config`` must
    provide ``d_model``, ``num_heads``, ``num_layers``, ``d_ff`` and
    ``max_len``.

    Args:
        checkpoint_path: File to load (defaults to ``MODEL_PATH``).
        device: Target device. When omitted, CUDA is used if available and
            CPU otherwise. Stored tensors are mapped onto it and the model is
            moved there.

    Returns:
        ``(model, vocab, config)``; the model is built with
        ``vocab_size=len(vocab)``. The model's train/eval mode is not changed
        here.

    NOTE(review): depending on the PyTorch version's ``weights_only``
    default, ``torch.load`` may unpickle arbitrary objects; only load trusted
    checkpoints.
    """
    target_device = device if device is not None else torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    vocab, config = checkpoint["vocab"], checkpoint["config"]

    # Architecture hyper-parameters are forwarded verbatim from the config.
    arch_keys = ("d_model", "num_heads", "num_layers", "d_ff", "max_len")
    architecture = {key: config[key] for key in arch_keys}

    model = TransformerClassifier(vocab_size=len(vocab), **architecture)
    model = model.to(target_device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model, vocab, config
|
|
def main():
    """
    Command-line entry point for LLM-assisted error analysis.

    Loads the trained checkpoint, runs it over the requested IMDB split, asks
    the LLM to explain up to ``--num_examples`` misclassified reviews, and
    writes the collected results to ``--output`` as JSON.
    """
    parser = argparse.ArgumentParser(
        description="LLM-assisted error analysis of the IMDB Transformer classifier."
    )
    parser.add_argument(
        "--split", type=str, default="test",
        help="Dataset split passed to load_imdb_texts (default: %(default)s).",
    )
    parser.add_argument(
        "--num_examples", type=int, default=10,
        help="Maximum number of misclassified reviews to analyze (default: %(default)s).",
    )
    parser.add_argument(
        "--output", type=str, default="error_analysis.json",
        help="Path of the JSON report (default: %(default)s).",
    )
    parser.add_argument(
        "--checkpoint", type=str, default=MODEL_PATH,
        help="Path to the trained model checkpoint (default: %(default)s).",
    )
    args = parser.parse_args()
    if args.num_examples < 1:
        parser.error("--num_examples must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, vocab, config = load_trained_model_from_checkpoint(
        checkpoint_path=args.checkpoint, device=device
    )

    texts, labels = load_imdb_texts(split=args.split)

    errors = analyze_misclassifications_on_texts(
        model=model,
        texts=texts,
        labels=labels,
        vocab=vocab,
        max_len=config["max_len"],
        device=device,
        num_examples=args.num_examples,
    )

    # The report embeds raw review text, so pin the encoding instead of
    # relying on the platform default.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(errors, f, indent=4)
    print(f"\nAnalysis complete. Results saved to {args.output}")


if __name__ == "__main__":
    main()