import numpy as np
import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import v2
from transformers import CLIPForImageClassification, SiglipForImageClassification

from common import LABELS, MODELS, new_id_to_old_id
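# common.py (not shown here) is assumed to provide:
#   LABELS           - list of the 27 art-style class names
#   MODELS           - dict mapping a display name to (hf_model_id, checkpoint_path)
#   new_id_to_old_id - maps the 14-class model's indices back into LABELS indices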

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Inference-time preprocessing: deterministic resize plus ImageNet
# normalization (no random train-time augmentation at prediction time).
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(size=(224, 224), antialias=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
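# Hypothetical sanity check of the pipeline's output shape:
#   from PIL import Image
#   x = transform(Image.open("examples/example1.jpeg"))  # float32, (3, 224, 224)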

current_model_name = None
model = None

def load_model(model_name):
    global current_model_name, model

    if model_name == current_model_name:
        return f"Model '{model_name}' is already loaded"

    # Release the previous model before loading a new one.
    if model is not None:
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_hf_name, head_name = MODELS[model_name]

    if model_name == "CLIP ViT-B/32 (27 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=27)
    elif model_name == "SigLIP2 ViT-B/16 (27 классов)":
        model = SiglipForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features=768, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=512, out_features=27),
        )
    elif model_name == "CLIP ViT-B/32 (14 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=14)
    else:
        model = torchvision.models.resnet50(weights='DEFAULT')
        model.fc = nn.Linear(in_features=2048, out_features=27)

    checkpoint = torch.load(head_name, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    current_model_name = model_name
    return f"Model '{model_name}' loaded"
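# Assumption: each checkpoint referenced by head_name stores a full state_dict
# for the assembled model (backbone plus replaced classifier head), matching
# the load_state_dict call above. Hypothetical warm-up before the first request:
#   print(load_model("CLIP ViT-B/32 (27 классов)"))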

def classify_image(image, top_k=3, model_name="CLIP ViT-B/32 (27 классов)"):
    global current_model_name, model

    if model is None or current_model_name != model_name:
        status = load_model(model_name)
        print(status)

    inputs = transform(image).to(device)
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(0))

    # torchvision's ResNet returns raw logits directly; the HF classifiers
    # return an output object carrying them in .logits.
    if current_model_name == "ResNet50 (27 классов)":
        probs = torch.softmax(outputs, dim=1)
    else:
        probs = torch.softmax(outputs.logits, dim=1)
    probs = probs.cpu().numpy()[0]

    sorted_indices = np.argsort(probs)[::-1]
    results = {}
    for i in range(min(top_k, len(probs))):
        idx = sorted_indices[i]
        # The 14-class model has its own index space; map back to LABELS.
        if current_model_name == "CLIP ViT-B/32 (14 классов)":
            label = LABELS[new_id_to_old_id[idx]]
        else:
            label = LABELS[idx]
        results[label] = float(probs[idx])
    return results
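# Hypothetical usage:
#   from PIL import Image
#   classify_image(Image.open("examples/example1.jpeg"), top_k=5)
#   -> dict of up to 5 {label: probability} entries, highest probability first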

def create_interface():
    def process_image(image, top_k, model_name):
        results = classify_image(image, top_k=int(top_k), model_name=model_name)
        output = "Classification results:\n\n"
        for label, prob in results.items():
            percentage = prob * 100
            # 20-segment text progress bar, one segment per 5%.
            bar_length = int(percentage / 5)
            bar = "█" * bar_length + "░" * (20 - bar_length)
            output += f"**{label}**: {percentage:.1f}%\n`{bar}`\n\n"
        return output

    description = """
# Identify the style of an artwork

## How to use:
1. Upload an image
2. Choose the number of top-N predictions to show
3. Click "Classify"
"""
    with gr.Blocks() as demo:
        gr.Markdown("### Model settings")
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/32 (27 классов)",
            label="Select a model",
        )
        load_model_btn = gr.Button("Load model", variant="secondary")
        model_status = gr.Textbox(label="Status", interactive=False)

        def on_model_change(model_name):
            return load_model(model_name)

        # Load the model both on dropdown change and on an explicit button click.
        model_dropdown.change(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status],
        )
        load_model_btn.click(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status],
        )
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload an image")
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Top-N predictions",
                )
                submit_btn = gr.Button("Classify", variant="primary")
            with gr.Column(scale=1):
                output_text = gr.Markdown(label="Results")

        submit_btn.click(
            fn=process_image,
            inputs=[image_input, top_k_slider, model_dropdown],
            outputs=[output_text],
        )
| gr.Markdown("### Примеры изображений для тестирования:") | |
| example_images = [ | |
| ["examples/example1.jpeg"], | |
| ["examples/example2.jpeg"], | |
| ["examples/example3.jpg"], | |
| ["examples/example4.jpg"], | |
| ["examples/example5.jpg"], | |
| ["examples/example6.jpg"], | |
| ] | |
| gr.Examples( | |
| examples=example_images, | |
| inputs=[image_input], | |
| ) | |
| return demo | |

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
    )