|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import pickle |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Shared preprocessing pipeline applied to every input image before it is
# fed to either model: resize to the fixed 224x224 input size, convert the
# PIL image to a float tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the standard ImageNet
# statistics — presumably both models were trained on ImageNet-normalized
# inputs; confirm against the training code.
transform = transforms.Compose([


    transforms.Resize((224, 224)),


    transforms.ToTensor(),


    transforms.Normalize(mean=[0.485, 0.456, 0.406],


                         std=[0.229, 0.224, 0.225])


])
|
|
|
|
|
|
|
|
# Module-level cache so the pickled config is read from disk at most once.
_action_class_names = None


def get_action_class_names():
    """Return the list of action class names, loading it lazily from disk.

    The names are read once from ``models/action_model_config.pkl``
    (relative to this file) and cached in a module-level global; later
    calls return the cached list without touching the filesystem.

    Returns:
        list: Class names stored under ``'class_names'`` in the pickled
        config dict.
    """
    global _action_class_names
    # Fast path: already loaded on a previous call.
    if _action_class_names is not None:
        return _action_class_names

    config_path = Path(__file__).parent / 'models' / 'action_model_config.pkl'
    # NOTE(review): pickle.load is only safe because this file ships with
    # the application; never point this at untrusted input.
    with open(config_path, 'rb') as f:
        _action_class_names = pickle.load(f)['class_names']
    return _action_class_names
|
|
|
|
|
def generate_caption(model, image, vocab, device, max_length=30):
    """
    Generate caption for an image

    Args:
        model: Trained caption model
        image: PIL Image
        vocab: Vocabulary object
        device: torch device
        max_length: Maximum caption length

    Returns:
        caption: Generated caption string
    """
    model.eval()

    # Preprocess and add a batch dimension of 1.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        token_ids = model.generate_caption(batch, max_length)

    # Decode the single batch element back into word tokens.
    tokens = vocab.decode(token_ids[0].cpu().numpy())

    # Keep only the content words: skip the start marker, stop at the
    # first end or pad token.
    words = []
    for token in tokens:
        if token == vocab.start_token:
            continue
        if token in (vocab.end_token, vocab.pad_token):
            break
        words.append(token)

    sentence = ' '.join(words)

    # Capitalize the first character, if any.
    return sentence[0].upper() + sentence[1:] if sentence else sentence
|
|
|
|
|
def predict_action(model, image, device):
    """
    Predict action for an image

    Args:
        model: Trained action model
        image: PIL Image
        device: torch device

    Returns:
        dict: Prediction results with class, confidence, and all predictions
    """
    model.eval()

    class_names = get_action_class_names()

    # Preprocess and add a batch dimension of 1.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        top_prob, top_idx = probabilities.max(dim=1)

    # Per-class probabilities for the single batch element, as percentages.
    percentages = probabilities[0].cpu().numpy() * 100

    # Pair each class with its probability, highest first.
    ranked = sorted(
        (
            {'class': name, 'probability': float(pct)}
            for name, pct in zip(class_names, percentages)
        ),
        key=lambda entry: entry['probability'],
        reverse=True,
    )

    return {
        'predicted_class': class_names[top_idx.item()],
        'confidence': float(top_prob.item() * 100),
        'all_predictions': ranked[:5],
    }