Spaces:
Sleeping
Sleeping
#!/usr/bin/env python
"""Debug script: inspect CEFRModel predictions and confidences.

Loads a trained ``CEFRModel`` checkpoint and runs one forward pass by hand
(tokenize -> model -> softmax/argmax), printing the intermediate tensors.
It then prints three equivalent ways of reading the predicted-class
probability and compares them with the model's own ``predict_batch`` and
``predict_sentence`` helpers, so a discrepancy in reported confidence is
easy to spot.

Usage:
    python debug_confidence.py [SENTENCE ...]

With no arguments the built-in Swedish test sentence is used.
"""
import os
import sys

import torch

# Repository root containing the ``web_app`` package. Overridable via the
# TEXTMINING_ROOT environment variable so the script is not tied to one
# machine; the default is the path the script originally hard-coded.
PROJECT_ROOT = os.environ.get("TEXTMINING_ROOT", "/home/fwl/src/textmining")

# Checkpoint handed to CEFRModel. A relative path is resolved by whatever
# CEFRModel does with it — presumably against the current working directory.
MODEL_PATH = "runs/metric-proto-k3/metric_proto.pt"

# Sentences used when none are supplied on the command line.
DEFAULT_SENTENCES = ["Jag är bra."]


def main(sentences=None, model_path=MODEL_PATH):
    """Run the manual-vs-helper prediction comparison and print everything.

    Args:
        sentences: Sentences to classify. ``None`` or an empty sequence falls
            back to ``DEFAULT_SENTENCES`` (an empty list would otherwise make
            the final ``predict_sentence(sentences[0])`` call fail).
        model_path: Checkpoint path passed to ``CEFRModel(model_path=...)``.
    """
    if PROJECT_ROOT not in sys.path:
        sys.path.append(PROJECT_ROOT)
    # Deferred import: ``web_app`` is only importable once PROJECT_ROOT is on
    # sys.path, so this cannot live with the top-of-file imports.
    from web_app.model import CEFRModel

    sentences = list(sentences) if sentences else list(DEFAULT_SENTENCES)

    print("Loading model...")
    model = CEFRModel(model_path=model_path)

    # Test simple sentence
    print(f"\nTesting: {sentences}")

    # Tokenize and move the tensors to the model's device.
    encoded = model.tokenize(sentences)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    print(f"Input shape: {input_ids.shape}")
    print(f"Device: {model.device}")

    # Predict; inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model.model(input_ids, attention_mask)["logits"]
    print(f"Logits shape: {logits.shape}")
    print(f"Logits: {logits}")

    # Reducing over dim=1 assumes logits are (batch, num_labels) — TODO confirm.
    probs = torch.softmax(logits, dim=1)
    print(f"Probs shape: {probs.shape}")
    print(f"Probs: {probs}")
    predictions = torch.argmax(logits, dim=1)
    print(f"Predictions: {predictions}")

    # Test different ways to extract confidence
    cpu_probs = probs.cpu()
    # .tolist() yields plain Python ints, so the label lookup and indexing
    # below do not depend on numpy scalar behaviour.
    for i, pred in enumerate(predictions.cpu().tolist()):
        print(f"\nSentence {i}: '{sentences[i]}'")
        print(f" Predicted class: {pred}")
        print(f" Predicted level: {model.id_to_label[pred]}")
        print(f" Method 1 - probs[i][pred]: {probs[i][pred].item()}")
        print(f" Method 2 - cpu_probs[i][pred]: {cpu_probs[i][pred].item()}")
        print(f" Method 3 - float(cpu_probs[i][pred].item()): {float(cpu_probs[i][pred].item())}")

    # Test using predict_batch
    print("\n" + "=" * 60)
    print("Using predict_batch method:")
    results = model.predict_batch(sentences)
    # Assumes predict_batch returns one (level, confidence) pair per sentence,
    # with confidence in [0, 1] (it is printed as a percentage) — TODO confirm.
    for sent, (level, conf) in zip(sentences, results):
        print(f" {level} ({conf*100:.1f}%): {sent}")

    # Test using predict_sentence
    print("\n" + "=" * 60)
    print("Using predict_sentence method:")
    level, conf = model.predict_sentence(sentences[0])
    print(f" {level} ({conf*100:.1f}%): {sentences[0]}")


if __name__ == "__main__":
    main(sys.argv[1:] or None)