"""Gradio app for classifying the artistic style of an image.

Supports several fine-tuned backbones (CLIP, SigLIP2, ResNet50) whose
weights are loaded from local checkpoints listed in ``common.MODELS``.
"""

import numpy as np
import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import v2
from transformers import CLIPForImageClassification, SiglipForImageClassification

from common import LABELS, MODELS, new_id_to_old_id

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Deterministic preprocessing for inference.  The previous pipeline applied
# RandomResizedCrop + RandomHorizontalFlip — training-time augmentations that
# made predictions non-deterministic; a fixed resize is used instead.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(size=(224, 224), antialias=False),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Lazily-loaded model state shared between Gradio callbacks.
current_model_name = None
model = None


def load_model(model_name):
    """Load (or switch to) the model identified by *model_name*.

    Frees the previously loaded model first to keep GPU memory bounded.

    Args:
        model_name: A key of ``common.MODELS``.

    Returns:
        A human-readable status string shown in the UI.
    """
    global current_model_name, model

    if model_name == current_model_name:
        return f"Модель '{model_name}' уже загружена"

    if model is not None:
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_hf_name, head_name = MODELS[model_name]

    # Rebuild each architecture with the classifier head shape it was
    # fine-tuned with, then load the checkpointed weights below.
    if model_name == "CLIP ViT-B/32 (27 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=27)
    elif model_name == "SigLIP2 ViT-B/16 (27 классов)":
        model = SiglipForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features=768, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=512, out_features=27),
        )
    elif model_name == "CLIP ViT-B/32 (14 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=14)
    else:
        # ResNet50 baseline (27 classes).
        model = torchvision.models.resnet50(weights='DEFAULT')
        model.fc = nn.Linear(in_features=2048, out_features=27)

    checkpoint = torch.load(head_name, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    current_model_name = model_name
    return f"Модель '{model_name}' загружена"


def classify_image(image, top_k=3, model_name="CLIP ViT-B/32 (27 классов)"):
    """Classify *image* with the requested model.

    Loads/switches the model on demand.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``).
        top_k: Number of top predictions to return.
        model_name: A key of ``common.MODELS``.

    Returns:
        Dict mapping label name -> probability, highest first.
    """
    global current_model_name, model

    if model is None or current_model_name != model_name:
        status = load_model(model_name)
        print(status)

    inputs = transform(image).to(device)
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(0))

    # torchvision models return raw logits; HF models wrap them in .logits.
    if current_model_name == "ResNet50 (27 классов)":
        probs = torch.softmax(outputs, dim=1)
    else:
        probs = torch.softmax(outputs.logits, dim=1)

    probs = probs.cpu().numpy()[0]
    sorted_indices = np.argsort(probs)[::-1]

    results = {}
    # Bound by the number of classes the current model actually predicts:
    # the 14-class model outputs fewer logits than len(LABELS), so using
    # len(LABELS) here could index past the end of `probs`.
    for i in range(min(top_k, len(probs))):
        idx = sorted_indices[i]
        # The 14-class model uses a remapped label space.
        if current_model_name == "CLIP ViT-B/32 (14 классов)":
            label = LABELS[new_id_to_old_id[idx]]
        else:
            label = LABELS[idx]
        results[label] = float(probs[idx])
    return results


def create_interface():
    """Build and return the Gradio Blocks UI."""

    def process_image(image, top_k, model_name):
        # Render the top-k probabilities as markdown with a text bar chart.
        results = classify_image(image, top_k=int(top_k), model_name=model_name)
        output = "Результаты классификации:\n\n"
        for label, prob in results.items():
            percentage = prob * 100
            bar_length = int(percentage / 5)
            bar = "█" * bar_length + "░" * (20 - bar_length)
            output += f"**{label}**: {percentage:.1f}%\n`{bar}`\n\n"
        return output

    # NOTE: the original list was numbered 1, 3, 4 — renumbered sequentially.
    description = """
# Определение стиля произведения искусства

## Как использовать:
1. Загрузите изображение
2. Выберите количество топ-N предсказаний
3. Нажмите "Классифицировать"
"""

    with gr.Blocks() as demo:
        gr.Markdown("### Настройки модели")
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/32 (27 классов)",
            label="Выберите модель",
        )
        load_model_btn = gr.Button("Загрузить модель", variant="secondary")
        model_status = gr.Textbox(label="Статус", interactive=False)

        def on_model_change(model_name):
            return load_model(model_name)

        model_dropdown.change(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        load_model_btn.click(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status]
        )

        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Загрузите изображение")
                # Label fixed: was "Ton-N" (latin T/o/n typo) instead of "Топ-N".
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Топ-N предсказаний"
                )
                submit_btn = gr.Button("Классифицировать", variant="primary")
            with gr.Column(scale=1):
                output_text = gr.Markdown(label="Результаты")

        submit_btn.click(
            fn=process_image,
            inputs=[image_input, top_k_slider, model_dropdown],
            outputs=[output_text]
        )

        gr.Markdown("### Примеры изображений для тестирования:")
        example_images = [
            ["examples/example1.jpeg"],
            ["examples/example2.jpeg"],
            ["examples/example3.jpg"],
            ["examples/example4.jpg"],
            ["examples/example5.jpg"],
            ["examples/example6.jpg"],
        ]
        gr.Examples(
            examples=example_images,
            inputs=[image_input],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False
    )