# HSE_DL_Project / main.py
# misshimichka's picture
# Upload folder using huggingface_hub
# 54061a0 verified
import numpy as np
import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import v2
from transformers import CLIPForImageClassification, SiglipForImageClassification
from common import LABELS, MODELS, new_id_to_old_id
# Select the compute device once at import time; all models are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Deterministic inference-time preprocessing.
# NOTE(review): the original pipeline used RandomResizedCrop + RandomHorizontalFlip,
# which are training-time augmentations and made predictions non-deterministic
# at serving time; replaced with a fixed resize to the 224x224 input the models
# expect. Normalization uses the standard ImageNet mean/std.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(size=(224, 224), antialias=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Lazily-populated globals: the currently loaded model and its display name.
current_model_name = None
model = None
def load_model(model_name):
    """Load the classifier selected by *model_name* into the module-global `model`.

    Skips the work if the requested model is already loaded. Otherwise frees the
    previous model, builds the architecture matching the dropdown choice,
    restores the fine-tuned weights from the checkpoint path in MODELS, moves
    the model to `device`, and switches it to eval mode.

    Args:
        model_name: One of the keys of `common.MODELS`.

    Returns:
        A (Russian) status string for the UI.

    Raises:
        KeyError: if `model_name` is not a key of MODELS.
    """
    global current_model_name, model
    if model_name == current_model_name:
        return f"Модель '{model_name}' уже загружена"
    if model is not None:
        # Drop the previous model and return its GPU memory to the allocator.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    model_hf_name, head_name = MODELS[model_name]
    if model_name == "CLIP ViT-B/32 (27 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=27)
    elif model_name == "SigLIP2 ViT-B/16 (27 классов)":
        model = SiglipForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features=768, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=512, out_features=27),
        )
    elif model_name == "CLIP ViT-B/32 (14 классов)":
        model = CLIPForImageClassification.from_pretrained(model_hf_name)
        model.classifier = nn.Linear(in_features=768, out_features=14)
    else:
        # Fallback: ResNet50 with a replaced 27-class head.
        model = torchvision.models.resnet50(weights='DEFAULT')
        model.fc = nn.Linear(in_features=2048, out_features=27)
    # Checkpoints are stored on CPU; the full state dict (backbone + head) is
    # restored before moving to the target device.
    checkpoint = torch.load(head_name, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    current_model_name = model_name
    return f"Модель '{model_name}' загружена"
def classify_image(image, top_k=3, model_name="CLIP ViT-B/32 (27 классов)"):
    """Classify *image* and return the top-k labels with their probabilities.

    Loads (or switches to) the requested model on demand, runs a single
    forward pass, and maps class indices back to human-readable labels —
    via `new_id_to_old_id` for the reduced 14-class model.

    Args:
        image: PIL image from the Gradio input (may be None if nothing uploaded).
        top_k: Number of top predictions to return.
        model_name: Display name of the model, a key of MODELS.

    Returns:
        dict mapping label -> probability, ordered from most to least likely;
        empty if no image was provided.
    """
    global current_model_name, model
    if image is None:
        return {}
    if model is None or current_model_name != model_name:
        status = load_model(model_name)
        print(status)
    inputs = transform(image).to(device)
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(0))
    # HF classifiers wrap logits in an output object; torchvision's ResNet
    # returns the raw tensor directly.
    if current_model_name == "ResNet50 (27 классов)":
        probs = torch.softmax(outputs, dim=1)
    else:
        probs = torch.softmax(outputs.logits, dim=1)
    probs = probs.cpu().numpy()[0]
    sorted_indices = np.argsort(probs)[::-1]
    results = {}
    # Clamp to the number of classes the *current* model actually predicts
    # (was len(LABELS)=27, which could index past a 14-class output).
    for i in range(min(top_k, len(probs))):
        idx = sorted_indices[i]
        if current_model_name == "CLIP ViT-B/32 (14 классов)":
            label = LABELS[new_id_to_old_id[idx]]
        else:
            label = LABELS[idx]
        prob = float(probs[idx])
        results[label] = prob
    return results
def create_interface():
    """Build and return the Gradio Blocks UI for the art-style classifier."""

    def process_image(image, top_k, model_name):
        # Run classification and render the results as markdown with a
        # 20-character text progress bar per label (1 block per 5%).
        results = classify_image(image, top_k=int(top_k), model_name=model_name)
        output = "Результаты классификации:\n\n"
        for label, prob in results.items():
            percentage = prob * 100
            bar_length = int(percentage / 5)
            bar = "█" * bar_length + "░" * (20 - bar_length)
            output += f"**{label}**: {percentage:.1f}%\n`{bar}`\n\n"
        return output

    # Fixed list numbering (was 1/3/4 with step 2 missing).
    description = """
# Определение стиля произведения искусства
## Как использовать:
1. Загрузите изображение
2. Выберите количество топ-N предсказаний
3. Нажмите "Классифицировать"
"""
    with gr.Blocks() as demo:
        gr.Markdown("### Настройки модели")
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="CLIP ViT-B/32 (27 классов)",
            label="Выберите модель",
        )
        load_model_btn = gr.Button("Загрузить модель", variant="secondary")
        model_status = gr.Textbox(label="Статус", interactive=False)

        def on_model_change(model_name):
            # Eagerly (re)load so the user sees the status before classifying.
            return load_model(model_name)

        model_dropdown.change(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        load_model_btn.click(
            fn=on_model_change,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Загрузите изображение")
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    # Fixed typo: label was "Ton-N" (Latin letters).
                    label="Топ-N предсказаний"
                )
                submit_btn = gr.Button("Классифицировать", variant="primary")
            with gr.Column(scale=1):
                output_text = gr.Markdown(label="Результаты")
        submit_btn.click(
            fn=process_image,
            inputs=[image_input, top_k_slider, model_dropdown],
            outputs=[output_text]
        )
        gr.Markdown("### Примеры изображений для тестирования:")
        example_images = [
            ["examples/example1.jpeg"],
            ["examples/example2.jpeg"],
            ["examples/example3.jpg"],
            ["examples/example4.jpg"],
            ["examples/example5.jpg"],
            ["examples/example6.jpg"],
        ]
        gr.Examples(
            examples=example_images,
            inputs=[image_input],
        )
    return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=False
)