"""
WeatherFlow Training - Pure FastAPI App
Train flow matching models for hurricane wind field prediction
API-only approach to avoid Gradio/huggingface_hub compatibility issues
"""
import os
import sys
import json
import time
import tempfile
import shutil
import asyncio
import threading
import uuid
import base64
from io import BytesIO
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any, List
from contextlib import asynccontextmanager
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gdown
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
# ============================================================================
# Global Training State (for tracking progress)
# ============================================================================
training_jobs: Dict[str, Dict[str, Any]] = {}
# ============================================================================
# Flow Matching Model Architecture
# ============================================================================
class SinusoidalPositionEmbeddings(nn.Module):
"""Time embeddings for flow matching"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = np.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
class ResidualBlock(nn.Module):
"""Residual block with time embedding"""
def __init__(self, in_channels, out_channels, time_emb_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.time_mlp = nn.Linear(time_emb_dim, out_channels)
self.norm1 = nn.GroupNorm(8, out_channels)
self.norm2 = nn.GroupNorm(8, out_channels)
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
else:
self.shortcut = nn.Identity()
def forward(self, x, time_emb):
h = self.conv1(x)
h = self.norm1(h)
h = torch.relu(h)
time_emb = self.time_mlp(time_emb)
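        # Broadcast the per-sample time embedding over the spatial dimensions (H, W)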
h = h + time_emb[:, :, None, None]
h = self.conv2(h)
h = self.norm2(h)
h = torch.relu(h)
return h + self.shortcut(x)
class FlowMatchingUNet(nn.Module):
"""U-Net architecture for flow matching velocity prediction"""
    def __init__(self, in_channels=6, out_channels=3, hidden_dim=128, num_layers=4, use_attention=True):
        # NOTE: use_attention is accepted for configuration compatibility, but no attention layers are implemented in this U-Net.
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
time_dim = hidden_dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(hidden_dim),
nn.Linear(hidden_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
self.init_conv = nn.Conv2d(in_channels, hidden_dim, 3, padding=1)
# Build channel list for encoder
channels = [hidden_dim * (2 ** i) for i in range(num_layers)]
# Encoder
self.encoder_blocks = nn.ModuleList()
self.downsample = nn.ModuleList()
prev_ch = hidden_dim
for ch in channels:
self.encoder_blocks.append(ResidualBlock(prev_ch, ch, time_dim))
self.downsample.append(nn.Conv2d(ch, ch, 4, stride=2, padding=1))
prev_ch = ch
# Middle
self.middle = ResidualBlock(channels[-1], channels[-1], time_dim)
# Decoder - need to handle skip connections properly
# After middle block, we have channels[-1] channels
# We upsample and concatenate with encoder outputs in reverse order
self.decoder_blocks = nn.ModuleList()
self.upsample = nn.ModuleList()
# Track what channels we have at each decoder stage
# Start with middle output: channels[-1]
# Encoder outputs (in order): hidden_dim, hidden_dim*2, hidden_dim*4, hidden_dim*8, ...
# We concatenate in reverse: channels[-1], channels[-2], ...
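        # Worked example with the defaults (hidden_dim=128, num_layers=4),
        # i.e. channels = [128, 256, 512, 1024]:
        #   i=0: upsample 1024 -> concat skip 1024 -> ResidualBlock(2048 -> 512)
        #   i=1: upsample 512  -> concat skip 512  -> ResidualBlock(1024 -> 256)
        #   i=2: upsample 256  -> concat skip 256  -> ResidualBlock(512  -> 128)
        #   i=3: upsample 128  -> concat skip 128  -> ResidualBlock(256  -> 128)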
decoder_in_ch = channels[-1] # Output of middle block
for i in range(num_layers):
# Skip connection from encoder (reversed order)
skip_ch = channels[num_layers - 1 - i]
# Output channels for this decoder block
out_ch = channels[num_layers - 2 - i] if i < num_layers - 1 else hidden_dim
# Upsample from current channels
self.upsample.append(nn.ConvTranspose2d(decoder_in_ch, decoder_in_ch, 4, stride=2, padding=1))
# After concat: decoder_in_ch + skip_ch
self.decoder_blocks.append(ResidualBlock(decoder_in_ch + skip_ch, out_ch, time_dim))
# Next iteration starts with out_ch
decoder_in_ch = out_ch
self.final_conv = nn.Sequential(
nn.GroupNorm(8, hidden_dim),
nn.SiLU(),
nn.Conv2d(hidden_dim, out_channels, 3, padding=1),
)
def forward(self, x, t):
time_emb = self.time_mlp(t)
x = self.init_conv(x)
# Encoder with skip connections
encoder_outputs = []
for block, down in zip(self.encoder_blocks, self.downsample):
x = block(x, time_emb)
encoder_outputs.append(x)
x = down(x)
# Middle
x = self.middle(x, time_emb)
# Decoder with skip connections
for i, (up, block) in enumerate(zip(self.upsample, self.decoder_blocks)):
x = up(x)
# Get corresponding encoder output (reversed order)
skip = encoder_outputs[-(i + 1)]
# Handle size mismatch
if x.shape[2:] != skip.shape[2:]:
x = torch.nn.functional.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, skip], dim=1)
x = block(x, time_emb)
return self.final_conv(x)
# ============================================================================
# Dataset for Hurricane Image Pairs
# ============================================================================
class HurricaneImagePairDataset(Dataset):
"""Dataset for hurricane image pairs from Google Drive"""
def __init__(self, image_paths: List[str], image_size: int = 128):
self.image_paths = image_paths
self.image_size = image_size
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
img = Image.open(img_path).convert('RGB')
# Split image in half (left = brightness temp, right = wind field)
width, height = img.size
mid = width // 2
brightness_temp = img.crop((0, 0, mid, height))
wind_field = img.crop((mid, 0, width, height))
# Resize to target size
brightness_temp = brightness_temp.resize((self.image_size, self.image_size), Image.LANCZOS)
wind_field = wind_field.resize((self.image_size, self.image_size), Image.LANCZOS)
# Convert to tensors and normalize to [-1, 1]
bt_tensor = torch.tensor(np.array(brightness_temp)).permute(2, 0, 1).float() / 127.5 - 1
wf_tensor = torch.tensor(np.array(wind_field)).permute(2, 0, 1).float() / 127.5 - 1
return bt_tensor, wf_tensor
# ============================================================================
# Training Function
# ============================================================================
def download_gdrive_folder(gdrive_url: str, output_dir: str) -> List[str]:
"""Download all images from a Google Drive folder with robust error handling"""
os.makedirs(output_dir, exist_ok=True)
# Extract folder ID from URL
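    # Handles URLs of the form .../drive/folders/<id>, ...?id=<id>, or a bare folder ID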
if "folders/" in gdrive_url:
folder_id = gdrive_url.split("folders/")[1].split("?")[0].split("/")[0]
elif "id=" in gdrive_url:
folder_id = gdrive_url.split("id=")[1].split("&")[0]
else:
folder_id = gdrive_url.strip("/").split("/")[-1]
print(f"Downloading from folder ID: {folder_id}")
download_errors = []
# Try downloading with gdown - it will download what it can
try:
gdown.download_folder(id=folder_id, output=output_dir, quiet=False)
except Exception as e:
error_msg = str(e)
print(f"Warning during folder download: {error_msg}")
download_errors.append(error_msg)
# If first method fails completely, try alternative
try:
gdown.download_folder(url=gdrive_url, output=output_dir, quiet=False)
except Exception as e2:
print(f"Alternative download also had issues: {e2}")
download_errors.append(str(e2))
# Find all successfully downloaded image files
image_extensions = {'.png', '.jpg', '.jpeg', '.webp', '.bmp'}
image_paths = []
for root, dirs, files in os.walk(output_dir):
for file in files:
if Path(file).suffix.lower() in image_extensions:
file_path = os.path.join(root, file)
# Verify file is valid (not corrupted/partial)
try:
file_size = os.path.getsize(file_path)
if file_size > 1000: # At least 1KB
image_paths.append(file_path)
else:
print(f"Skipping small/incomplete file: {file} ({file_size} bytes)")
except Exception as e:
print(f"Error checking file {file}: {e}")
print(f"\n{'='*50}")
print(f"Download Summary:")
print(f" Successfully downloaded: {len(image_paths)} images")
if download_errors:
print(f" Some files may have been skipped due to rate limiting")
print(f" Training will continue with available images")
print(f"{'='*50}\n")
return sorted(image_paths)
def train_flow_matching(
job_id: str,
gdrive_url: str,
hidden_dim: int = 128,
num_layers: int = 4,
use_attention: bool = True,
epochs: int = 50,
batch_size: int = 4,
learning_rate: float = 0.0005,
image_size: int = 128,
hf_token: Optional[str] = None,
experiment_name: str = "hurricane_flow_matching"
):
"""Train flow matching model on hurricane image pairs"""
global training_jobs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
training_jobs[job_id]["status"] = "downloading"
training_jobs[job_id]["device"] = str(device)
# Download images from Google Drive
with tempfile.TemporaryDirectory() as temp_dir:
try:
image_paths = download_gdrive_folder(gdrive_url, temp_dir)
except Exception as e:
training_jobs[job_id]["status"] = "failed"
training_jobs[job_id]["error"] = f"Failed to download images: {str(e)}"
return
if len(image_paths) < 4:
training_jobs[job_id]["status"] = "failed"
training_jobs[job_id]["error"] = f"Not enough images found ({len(image_paths)}). Need at least 4."
return
training_jobs[job_id]["num_images"] = len(image_paths)
# Split into train/val
np.random.shuffle(image_paths)
split_idx = int(len(image_paths) * 0.8)
train_paths = image_paths[:split_idx]
val_paths = image_paths[split_idx:]
print(f"Train: {len(train_paths)}, Val: {len(val_paths)}")
# Create datasets
train_dataset = HurricaneImagePairDataset(train_paths, image_size)
val_dataset = HurricaneImagePairDataset(val_paths, image_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
# Create model (input: brightness temp (3) + noisy wind (3) = 6, output: velocity (3))
model = FlowMatchingUNet(
in_channels=6,
out_channels=3,
hidden_dim=hidden_dim,
num_layers=num_layers,
use_attention=use_attention
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
training_jobs[job_id]["status"] = "training"
training_jobs[job_id]["total_epochs"] = epochs
history = {"epoch": [], "train_loss": [], "val_loss": []}
for epoch in range(epochs):
# Check if cancelled
if training_jobs[job_id].get("cancelled", False):
training_jobs[job_id]["status"] = "cancelled"
return
model.train()
train_losses = []
for batch_idx, (brightness_temp, wind_field) in enumerate(train_loader):
brightness_temp = brightness_temp.to(device)
wind_field = wind_field.to(device)
# Sample random time t in [0, 1]
t = torch.rand(brightness_temp.shape[0], device=device)
# Sample noise
noise = torch.randn_like(wind_field)
# Interpolate: x_t = (1-t) * noise + t * wind_field
t_expand = t[:, None, None, None]
x_t = (1 - t_expand) * noise + t_expand * wind_field
# Target velocity: v = wind_field - noise
target_v = wind_field - noise
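                # Why this target: differentiating x_t = (1 - t) * noise + t * wind_field
                # with respect to t gives the constant velocity wind_field - noise, which
                # the network learns to predict (conditioned on the brightness-temperature
                # image) with a plain MSE loss.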
# Concatenate brightness temp with noisy wind field
model_input = torch.cat([brightness_temp, x_t], dim=1)
# Predict velocity
pred_v = model(model_input, t)
# MSE loss
loss = torch.nn.functional.mse_loss(pred_v, target_v)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_losses.append(loss.item())
# Validation
model.eval()
val_losses = []
with torch.no_grad():
for brightness_temp, wind_field in val_loader:
brightness_temp = brightness_temp.to(device)
wind_field = wind_field.to(device)
t = torch.rand(brightness_temp.shape[0], device=device)
noise = torch.randn_like(wind_field)
t_expand = t[:, None, None, None]
x_t = (1 - t_expand) * noise + t_expand * wind_field
target_v = wind_field - noise
model_input = torch.cat([brightness_temp, x_t], dim=1)
pred_v = model(model_input, t)
loss = torch.nn.functional.mse_loss(pred_v, target_v)
val_losses.append(loss.item())
train_loss = np.mean(train_losses)
val_loss = np.mean(val_losses) if val_losses else train_loss
scheduler.step()
# Update history
history["epoch"].append(epoch + 1)
history["train_loss"].append(float(train_loss))
history["val_loss"].append(float(val_loss))
# Update job status
training_jobs[job_id]["current_epoch"] = epoch + 1
training_jobs[job_id]["train_loss"] = float(train_loss)
training_jobs[job_id]["val_loss"] = float(val_loss)
training_jobs[job_id]["progress"] = ((epoch + 1) / epochs) * 100
training_jobs[job_id]["history"] = history
print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
# Training complete
training_jobs[job_id]["status"] = "completed"
training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
# Save model if HF token provided
if hf_token:
try:
from huggingface_hub import HfApi, create_repo
api = HfApi(token=hf_token)
user_info = api.whoami()
username = user_info["name"]
model_repo = f"{username}/weatherflow-{experiment_name}-{job_id[:8]}"
create_repo(repo_id=model_repo, token=hf_token, exist_ok=True)
# Save model locally first
model_path = os.path.join(temp_dir, "model.pt")
torch.save({
"model_state_dict": model.state_dict(),
"config": {
"hidden_dim": hidden_dim,
"num_layers": num_layers,
"use_attention": use_attention,
"image_size": image_size,
},
"history": history,
}, model_path)
api.upload_file(
path_or_fileobj=model_path,
path_in_repo="model.pt",
repo_id=model_repo,
token=hf_token
)
training_jobs[job_id]["model_repo"] = model_repo
print(f"Model saved to: https://huggingface.co/{model_repo}")
except Exception as e:
print(f"Error saving model to HF: {e}")
training_jobs[job_id]["model_save_error"] = str(e)
# Run inference on test set after training completes
print("Running inference on test set...")
training_jobs[job_id]["status"] = "running_inference"
try:
test_results = []
model.eval()
# Use validation set as test set (or a subset)
test_loader = val_loader if val_loader else train_loader
num_test_samples = min(10, len(test_loader.dataset)) # Limit to 10 samples
with torch.no_grad():
sample_count = 0
for brightness_temp, wind_field in test_loader:
if sample_count >= num_test_samples:
break
brightness_temp = brightness_temp.to(device)
wind_field = wind_field.to(device)
# Run flow matching inference
batch_size = brightness_temp.shape[0]
x = torch.randn(batch_size, 3, image_size, image_size).to(device)
num_steps = 20 # Faster inference for test set
dt = 1.0 / num_steps
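                    # Euler integration of the learned ODE dx/dt = v(x, t) from
                    # t = 0 (pure noise) to t = 1 (predicted wind field)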
for step in range(num_steps):
t = torch.tensor([step * dt] * batch_size).to(device)
model_input = torch.cat([brightness_temp, x], dim=1)
v = model(model_input, t)
x = x + v * dt
# Convert to images (base64)
for i in range(min(batch_size, num_test_samples - sample_count)):
# Input image
input_img = brightness_temp[i].cpu().permute(1, 2, 0).numpy()
                        input_img = np.clip((input_img + 1) * 127.5, 0, 255).astype(np.uint8)  # map [-1, 1] back to [0, 255]
input_pil = Image.fromarray(input_img)
input_buffer = BytesIO()
input_pil.save(input_buffer, format='PNG')
input_b64 = base64.b64encode(input_buffer.getvalue()).decode('utf-8')
# Ground truth
gt_img = wind_field[i].cpu().permute(1, 2, 0).numpy()
                        gt_img = np.clip((gt_img + 1) * 127.5, 0, 255).astype(np.uint8)  # map [-1, 1] back to [0, 255]
gt_pil = Image.fromarray(gt_img)
gt_buffer = BytesIO()
gt_pil.save(gt_buffer, format='PNG')
gt_b64 = base64.b64encode(gt_buffer.getvalue()).decode('utf-8')
# Predicted output
pred_img = x[i].cpu().permute(1, 2, 0).numpy()
                        pred_img = np.clip((pred_img + 1) * 127.5, 0, 255).astype(np.uint8)  # map [-1, 1] back to [0, 255]
pred_pil = Image.fromarray(pred_img)
pred_buffer = BytesIO()
pred_pil.save(pred_buffer, format='PNG')
pred_b64 = base64.b64encode(pred_buffer.getvalue()).decode('utf-8')
test_results.append({
"input": f"data:image/png;base64,{input_b64}",
"ground_truth": f"data:image/png;base64,{gt_b64}",
"prediction": f"data:image/png;base64,{pred_b64}",
"index": sample_count + i
})
sample_count += batch_size
training_jobs[job_id]["test_results"] = test_results
print(f"Generated {len(test_results)} test set predictions")
except Exception as e:
print(f"Error running test set inference: {e}")
training_jobs[job_id]["test_inference_error"] = str(e)
# Final status update
training_jobs[job_id]["status"] = "completed"
# ============================================================================
# FastAPI App
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
print("WeatherFlow Training API starting...")
yield
print("WeatherFlow Training API shutting down...")
app = FastAPI(
title="WeatherFlow Training API",
description="Train flow matching models for hurricane wind field prediction",
version="1.0.0",
lifespan=lifespan
)
# Enable CORS for all origins
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class TrainingRequest(BaseModel):
gdrive_url: str
experiment_name: str = "hurricane_flow_matching"
hidden_dim: int = 128
num_layers: int = 4
use_attention: bool = True
epochs: int = 50
batch_size: int = 4
learning_rate: float = 0.0005
image_size: int = 128
hf_token: Optional[str] = None
class TrainingResponse(BaseModel):
job_id: str
status: str
message: str
@app.get("/", response_class=HTMLResponse)
async def root():
"""Simple HTML interface"""
return """
<!DOCTYPE html>
<html>
<head>
<title>WeatherFlow Training API</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
h1 { color: #333; }
.endpoint { background: #f5f5f5; padding: 15px; margin: 10px 0; border-radius: 5px; }
code { background: #e0e0e0; padding: 2px 6px; border-radius: 3px; }
</style>
</head>
<body>
<h1>🌀 WeatherFlow Training API</h1>
<p>Train flow matching models for hurricane wind field prediction.</p>
<h2>API Endpoints</h2>
<div class="endpoint">
<h3>POST /api/train</h3>
<p>Start a new training job</p>
<p>Body: <code>{"gdrive_url": "...", "epochs": 50, ...}</code></p>
</div>
<div class="endpoint">
<h3>GET /api/status/{job_id}</h3>
<p>Get training status and progress</p>
</div>
<div class="endpoint">
<h3>POST /api/cancel/{job_id}</h3>
<p>Cancel a running training job</p>
</div>
<div class="endpoint">
<h3>GET /api/jobs</h3>
<p>List all training jobs</p>
</div>
<div class="endpoint">
<h3>GET /health</h3>
<p>Health check endpoint</p>
</div>
<p><a href="/docs">📚 Interactive API Documentation (Swagger UI)</a></p>
</body>
</html>
"""
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"cuda_available": torch.cuda.is_available(),
"device": "cuda" if torch.cuda.is_available() else "cpu",
"active_jobs": len([j for j in training_jobs.values() if j.get("status") == "training"])
}
@app.post("/api/train", response_model=TrainingResponse)
async def start_training(request: TrainingRequest):
"""Start a new training job"""
job_id = str(uuid.uuid4())
training_jobs[job_id] = {
"status": "queued",
"created_at": datetime.now().isoformat(),
"config": request.dict(),
"current_epoch": 0,
"total_epochs": request.epochs,
"train_loss": None,
"val_loss": None,
"progress": 0,
"history": {"epoch": [], "train_loss": [], "val_loss": []},
}
# Start training in a separate thread (not BackgroundTasks which blocks on CPU-intensive work)
def run_training():
try:
train_flow_matching(
job_id=job_id,
gdrive_url=request.gdrive_url,
hidden_dim=request.hidden_dim,
num_layers=request.num_layers,
use_attention=request.use_attention,
epochs=request.epochs,
batch_size=request.batch_size,
learning_rate=request.learning_rate,
image_size=request.image_size,
hf_token=request.hf_token,
experiment_name=request.experiment_name
)
except Exception as e:
print(f"Training error: {e}")
import traceback
traceback.print_exc()
training_jobs[job_id]["status"] = "failed"
training_jobs[job_id]["error"] = str(e)
thread = threading.Thread(target=run_training, daemon=True)
thread.start()
return TrainingResponse(
job_id=job_id,
status="queued",
message="Training job started in background thread"
)
@app.get("/api/status/{job_id}")
async def get_training_status(job_id: str):
"""Get status of a training job"""
if job_id not in training_jobs:
raise HTTPException(status_code=404, detail="Job not found")
return training_jobs[job_id]
@app.post("/api/cancel/{job_id}")
async def cancel_training(job_id: str):
"""Cancel a training job"""
if job_id not in training_jobs:
raise HTTPException(status_code=404, detail="Job not found")
training_jobs[job_id]["cancelled"] = True
return {"status": "cancelling", "message": "Cancellation requested"}
@app.get("/api/jobs")
async def list_jobs():
"""List all training jobs"""
return {
"jobs": [
{"job_id": job_id, **{k: v for k, v in job.items() if k != "history"}}
for job_id, job in training_jobs.items()
]
}
class InferenceRequest(BaseModel):
model_repo: str # HuggingFace model repo ID (e.g., "monkseal555/hurricane-flow-model")
image_url: str # URL to the input image
hf_token: Optional[str] = None # HuggingFace token for private repos
num_steps: int = 50 # Number of flow matching steps
hidden_dim: int = 128
num_layers: int = 4
@app.post("/api/inference")
async def run_inference(request: InferenceRequest):
"""Run inference using a trained model from HuggingFace"""
    import requests  # only needed here; BytesIO, Image and base64 are already imported at module level
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download model from HuggingFace
from huggingface_hub import hf_hub_download
# Use provided token or try without
token = request.hf_token if request.hf_token else None
try:
model_path = hf_hub_download(
repo_id=request.model_repo,
filename="model.pt",
token=token
)
except Exception as e:
# Try alternative filename
try:
model_path = hf_hub_download(
repo_id=request.model_repo,
filename="pytorch_model.bin",
token=token
)
except Exception as e2:
                raise HTTPException(status_code=404, detail=f"Could not download model from {request.model_repo}. model.pt error: {str(e)}; pytorch_model.bin error: {str(e2)}. Make sure the repo exists and you have access.")
# Load model
model = FlowMatchingUNet(
in_channels=6,
out_channels=3,
hidden_dim=request.hidden_dim,
num_layers=request.num_layers
).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        # Checkpoints saved by the training endpoint wrap the weights under "model_state_dict"
        state_dict = checkpoint.get("model_state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
        model.load_state_dict(state_dict)
model.eval()
# Download and process input image
response = requests.get(request.image_url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert('RGB')
img = img.resize((128, 128))
        # Convert to tensor and normalize to [-1, 1] (matching the training normalization)
        img_tensor = torch.from_numpy(np.array(img)).float() / 127.5 - 1
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device)  # [1, 3, 128, 128]
# Run flow matching inference
with torch.no_grad():
# Start from noise
x = torch.randn(1, 3, 128, 128).to(device)
# Flow matching: integrate from t=0 to t=1
dt = 1.0 / request.num_steps
for i in range(request.num_steps):
t = torch.tensor([i * dt]).to(device)
# Concatenate input image with current state
model_input = torch.cat([img_tensor, x], dim=1) # [1, 6, 128, 128]
# Predict velocity
v = model(model_input, t)
# Euler step
x = x + v * dt
# Convert output to image
output = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output = np.clip((output + 1) * 127.5, 0, 255).astype(np.uint8)  # map [-1, 1] back to [0, 255]
output_img = Image.fromarray(output)
# Encode as base64
buffer = BytesIO()
output_img.save(buffer, format='PNG')
output_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
return {
"status": "success",
"output_image": f"data:image/png;base64,{output_base64}",
"model_repo": request.model_repo,
"num_steps": request.num_steps,
"device": str(device)
}
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. the 404 above) instead of converting them to 500
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/models")
async def list_available_models():
"""List models that have been trained and saved"""
models = []
for job_id, job in training_jobs.items():
if job.get("status") == "completed" and job.get("model_url"):
models.append({
"job_id": job_id,
"model_url": job.get("model_url"),
"experiment_name": job.get("config", {}).get("experiment_name"),
"completed_at": job.get("completed_at")
})
return {"models": models}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
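# ---------------------------------------------------------------------------
# Example usage (a sketch, assuming the app is running locally on port 7860;
# replace <FOLDER_ID> with a real Google Drive folder ID):
#
#   import requests
#   resp = requests.post("http://localhost:7860/api/train", json={
#       "gdrive_url": "https://drive.google.com/drive/folders/<FOLDER_ID>",
#       "epochs": 10,
#   })
#   job_id = resp.json()["job_id"]
#   # Poll training progress
#   status = requests.get(f"http://localhost:7860/api/status/{job_id}").json()
# ---------------------------------------------------------------------------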