# Test / app.py
# AndersonConforto's picture
# first commit
# c4ccf03
# raw
# history blame
# 2.34 kB
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import requests
import os
# ================================
# 1. Download the Surya-1.0 weights
# ================================
# Local filename the checkpoint is stored under (same as the URL's last segment).
MODEL_FILE: str = "surya.366m.v1.pt"
# Hugging Face URL the checkpoint is fetched from by download_model().
MODEL_URL: str = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
def download_model(url=None, path=None, *, chunk_size=1 << 20, timeout=60):
    """Download the Surya-1.0 checkpoint to ``path`` unless that file already exists.

    The response is streamed to ``<path>.part`` and renamed into place only after
    the whole body has been written. An interrupted or failed download therefore
    never leaves a truncated file that the existence check on later runs would
    mistake for a complete checkpoint.

    Args:
        url: Source URL; defaults to ``MODEL_URL``.
        path: Destination file; defaults to ``MODEL_FILE``.
        chunk_size: Bytes written per streamed chunk (keeps memory use bounded).
        timeout: Connect/read timeout in seconds passed to ``requests``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    url = MODEL_URL if url is None else url
    path = MODEL_FILE if path is None else path
    if os.path.exists(path):
        return

    print("Baixando pesos do Surya-1.0...")
    tmp_path = f"{path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check a 4xx/5xx error page would be saved as the checkpoint.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or the full file.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; otherwise drop the partial
        # download so the next run retries instead of loading a corrupt file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download concluído!")


download_model()
# ================================
# 2. Define the Surya architecture
# ================================
# Paste or import the SuryaModel class from the official repository here.
# Generic example:
# NOTE(review): the class below is only a stand-in (a single 3x3 convolution).
# Its state_dict keys are "conv.weight"/"conv.bias", so loading the real
# Surya-1.0 checkpoint into it is expected to fail until the official
# architecture replaces it. TODO: confirm against the checkpoint's keys.
import torch.nn as nn


class SuryaModel(nn.Module):
    """Placeholder model: one 3x3 convolution from 1 channel to 1 channel.

    ``padding=1`` preserves the spatial size, so an input of shape
    (N, 1, H, W) produces an output of shape (N, 1, H, W).
    """

    def __init__(self):
        super().__init__()
        # The attribute name determines the state_dict keys ("conv.*").
        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        y = self.conv(x)
        return y
# ================================
# 3. Create the model instance and load the weights
# ================================
model = SuryaModel()
# weights_only=True (torch >= 1.13; the default since torch 2.6): the checkpoint
# was downloaded from the network, and full unpickling can execute arbitrary
# code. This restricts loading to tensors and plain containers. A checkpoint
# holding other pickled objects will raise instead of being silently executed.
state_dict = torch.load(MODEL_FILE, map_location=torch.device("cpu"), weights_only=True)
try:
    model.load_state_dict(state_dict)
except RuntimeError as err:
    # Same exception type as before, but explain the likely cause: the
    # SuryaModel defined above is a placeholder, not the official architecture.
    raise RuntimeError(
        f"O checkpoint '{MODEL_FILE}' não corresponde à arquitetura SuryaModel; "
        "substitua a classe de exemplo pela implementação oficial do Surya."
    ) from err
# Inference mode (affects layers such as dropout/batch-norm, if present).
model.eval()
# ================================
# 4. Inference function that renders a heatmap
# ================================
def infer_solar_image_heatmap(img):
    """Run ``model`` on an uploaded image and render its output as a heatmap.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A new ``matplotlib.figure.Figure`` showing the squeezed model output,
        min-max scaled to [0, 1] with the ``hot`` colormap. Returns ``None``
        (leaving the plot output empty) when ``img`` is ``None``.
    """
    # Use a standalone Figure instead of pyplot. pyplot's implicit "current
    # figure" is process-global: the old code kept adding images to the same
    # figure on every request, never released it, and shared it between
    # concurrent Gradio worker threads.
    from matplotlib.figure import Figure

    if img is None:
        return None

    # Grayscale, resized to 224x224, scaled to [0, 1], shaped (1, 1, 224, 224).
    gray = img.convert("L").resize((224, 224))
    img_tensor = torch.tensor(np.array(gray), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        outputs = model(img_tensor)
    emb = outputs.squeeze().numpy()

    # Min-max scale; the epsilon avoids division by zero for a constant output.
    heatmap = emb - emb.min()
    heatmap /= heatmap.max() + 1e-8

    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    fig.tight_layout()
    return fig
# ================================
# 5. Gradio interface
# ================================
# Input: an uploaded image delivered to the handler as a PIL image.
image_input = gr.Image(type="pil")
# Output: the matplotlib figure returned by infer_solar_image_heatmap.
heatmap_output = gr.Plot(label="Heatmap do embedding Surya")

interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=image_input,
    outputs=heatmap_output,
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)
# Start the web server (runs at import time, as before).
interface.launch()