| """ |
| WeatherFlow Training - Pure FastAPI App |
| Train flow matching models for hurricane wind field prediction |
| API-only approach to avoid Gradio/huggingface_hub compatibility issues |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import tempfile |
| import shutil |
| import asyncio |
| import threading |
| import uuid |
| import base64 |
| from io import BytesIO |
| from pathlib import Path |
| from datetime import datetime |
| from typing import Optional, Dict, Any, List |
| from contextlib import asynccontextmanager |
|
|
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| import gdown |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import HTMLResponse |
| from pydantic import BaseModel |
|
|
| |
| |
| |
|
|
# In-memory registry of training jobs, keyed by job UUID. Each record holds
# status, config, progress, loss history and (eventually) results. Mutated
# from both request handlers and background training threads; there is no
# explicit lock — plain dict item assignment is relied upon.
training_jobs: Dict[str, Dict[str, Any]] = {}
|
|
| |
| |
| |
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """Map scalar flow-matching times to sinusoidal embedding vectors.

    For a batch of times ``t`` of shape ``(B,)`` this returns a tensor of
    shape ``(B, dim)`` whose first half holds ``sin`` and second half ``cos``
    of the input scaled by geometrically spaced frequencies (the classic
    transformer positional-encoding scheme).
    """

    def __init__(self, dim):
        super().__init__()
        # Target embedding width; sin and cos each contribute dim // 2.
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # Log-spaced frequency decay from 1 down to 1/10000.
        decay = np.log(10000) / (half - 1)
        freqs = torch.exp(-decay * torch.arange(half, device=time.device))
        angles = time.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
class ResidualBlock(nn.Module):
    """Two-convolution residual block conditioned on a time embedding.

    The time embedding is projected to ``out_channels`` and added as a
    per-channel bias between the two convolutions. A 1x1 convolution adapts
    the skip connection whenever the channel count changes.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        # NOTE: module creation order is kept stable so seeded weight
        # initialization is reproducible.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        # Identity skip when shapes already agree, otherwise a 1x1 projection.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, time_emb):
        out = torch.relu(self.norm1(self.conv1(x)))
        # Broadcast the projected time embedding across spatial dims.
        out = out + self.time_mlp(time_emb)[:, :, None, None]
        out = torch.relu(self.norm2(self.conv2(out)))
        return out + self.shortcut(x)
|
|
|
|
class FlowMatchingUNet(nn.Module):
    """U-Net architecture for flow matching velocity prediction.

    Input is the concatenation of the 3-channel conditioning image and the
    3-channel noisy state (``in_channels=6``); output is the predicted
    3-channel velocity field. Channel width doubles at each of the
    ``num_layers`` encoder levels; the decoder mirrors it with skip
    connections from the encoder.

    NOTE(review): ``use_attention`` is accepted but never referenced below —
    no attention layers are built. Kept for interface compatibility; confirm
    whether attention was intended.
    """
    def __init__(self, in_channels=6, out_channels=3, hidden_dim=128, num_layers=4, use_attention=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Time embedding MLP: sinusoidal features expanded to 4x hidden_dim.
        time_dim = hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_dim),
            nn.Linear(hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, hidden_dim, 3, padding=1)

        # Per-level channel widths: hidden_dim, 2*hidden_dim, 4*hidden_dim, ...
        channels = [hidden_dim * (2 ** i) for i in range(num_layers)]

        # Encoder: residual block followed by a strided-conv 2x downsample.
        self.encoder_blocks = nn.ModuleList()
        self.downsample = nn.ModuleList()

        prev_ch = hidden_dim
        for ch in channels:
            self.encoder_blocks.append(ResidualBlock(prev_ch, ch, time_dim))
            self.downsample.append(nn.Conv2d(ch, ch, 4, stride=2, padding=1))
            prev_ch = ch

        # Bottleneck at the lowest resolution.
        self.middle = ResidualBlock(channels[-1], channels[-1], time_dim)

        # Decoder: transposed-conv 2x upsample, concatenate the matching
        # encoder skip, then a residual block that reduces the width.
        self.decoder_blocks = nn.ModuleList()
        self.upsample = nn.ModuleList()

        decoder_in_ch = channels[-1]
        for i in range(num_layers):
            # Skip connection width from the mirrored encoder level.
            skip_ch = channels[num_layers - 1 - i]
            # Each step narrows toward hidden_dim at the final level.
            out_ch = channels[num_layers - 2 - i] if i < num_layers - 1 else hidden_dim

            self.upsample.append(nn.ConvTranspose2d(decoder_in_ch, decoder_in_ch, 4, stride=2, padding=1))
            # Input width = upsampled features + concatenated skip.
            self.decoder_blocks.append(ResidualBlock(decoder_in_ch + skip_ch, out_ch, time_dim))
            decoder_in_ch = out_ch

        # Project back to the output channel count.
        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, hidden_dim),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, out_channels, 3, padding=1),
        )

    def forward(self, x, t):
        """Predict the velocity field for state ``x`` at times ``t``.

        Args:
            x: ``(B, in_channels, H, W)`` conditioning + noisy-state tensor.
            t: ``(B,)`` flow-matching times.

        Returns:
            ``(B, out_channels, H, W)`` predicted velocity.
        """
        time_emb = self.time_mlp(t)

        x = self.init_conv(x)

        # Encoder pass; pre-downsample activations feed the decoder skips.
        encoder_outputs = []
        for block, down in zip(self.encoder_blocks, self.downsample):
            x = block(x, time_emb)
            encoder_outputs.append(x)
            x = down(x)

        x = self.middle(x, time_emb)

        # Decoder pass with skip connections in reverse encoder order.
        for i, (up, block) in enumerate(zip(self.upsample, self.decoder_blocks)):
            x = up(x)

            skip = encoder_outputs[-(i + 1)]

            # Guard against off-by-one spatial sizes from odd inputs.
            if x.shape[2:] != skip.shape[2:]:
                x = torch.nn.functional.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = block(x, time_emb)

        return self.final_conv(x)
|
|
|
|
| |
| |
| |
|
|
class HurricaneImagePairDataset(Dataset):
    """Paired (brightness-temperature, wind-field) image dataset.

    Each file on disk is a single side-by-side image: the left half is the
    input brightness-temperature map and the right half is the target wind
    field. Both halves are resized to ``image_size`` squares and returned as
    float tensors of shape ``(3, image_size, image_size)`` scaled to
    ``[-1, 1]``.
    """

    def __init__(self, image_paths: List[str], image_size: int = 128):
        self.image_paths = image_paths
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        combined = Image.open(self.image_paths[idx]).convert('RGB')

        width, height = combined.size
        half = width // 2

        # Left half is the input, right half is the target.
        halves = (
            combined.crop((0, 0, half, height)),
            combined.crop((half, 0, width, height)),
        )

        target_size = (self.image_size, self.image_size)
        tensors = []
        for part in halves:
            resized = part.resize(target_size, Image.LANCZOS)
            chw = torch.tensor(np.array(resized)).permute(2, 0, 1).float()
            # Map uint8 [0, 255] pixels to [-1, 1].
            tensors.append(chw / 127.5 - 1)

        return tensors[0], tensors[1]
|
|
|
|
| |
| |
| |
|
|
def download_gdrive_folder(gdrive_url: str, output_dir: str) -> List[str]:
    """Download all images from a Google Drive folder with robust error handling.

    Accepts a full folder URL (``.../folders/<id>``), an ``id=`` query URL,
    or a bare folder id. Downloads are best-effort: gdown failures (commonly
    rate limiting) are logged, a second attempt is made with the raw URL, and
    whatever files landed on disk are used.

    Args:
        gdrive_url: Google Drive folder URL or id.
        output_dir: Directory to download into (created if missing).

    Returns:
        Sorted list of paths to image files larger than 1000 bytes
        (smaller files are assumed to be incomplete downloads).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Extract the folder id from the common URL shapes.
    if "folders/" in gdrive_url:
        folder_id = gdrive_url.split("folders/")[1].split("?")[0].split("/")[0]
    elif "id=" in gdrive_url:
        folder_id = gdrive_url.split("id=")[1].split("&")[0]
    else:
        # Assume the caller passed the bare id (or an id-terminated URL).
        folder_id = gdrive_url.strip("/").split("/")[-1]

    print(f"Downloading from folder ID: {folder_id}")

    download_errors = []

    # First attempt: download by extracted id.
    try:
        gdown.download_folder(id=folder_id, output=output_dir, quiet=False)
    except Exception as e:
        error_msg = str(e)
        print(f"Warning during folder download: {error_msg}")
        download_errors.append(error_msg)

        # Second attempt: let gdown parse the original URL itself.
        try:
            gdown.download_folder(url=gdrive_url, output=output_dir, quiet=False)
        except Exception as e2:
            print(f"Alternative download also had issues: {e2}")
            download_errors.append(str(e2))

    # Collect whatever image files actually arrived, recursively.
    image_extensions = {'.png', '.jpg', '.jpeg', '.webp', '.bmp'}
    image_paths = []

    for root, dirs, files in os.walk(output_dir):
        for file in files:
            if Path(file).suffix.lower() in image_extensions:
                file_path = os.path.join(root, file)

                # Skip tiny files — likely truncated by a failed download.
                try:
                    file_size = os.path.getsize(file_path)
                    if file_size > 1000:
                        image_paths.append(file_path)
                    else:
                        print(f"Skipping small/incomplete file: {file} ({file_size} bytes)")
                except Exception as e:
                    print(f"Error checking file {file}: {e}")

    print(f"\n{'='*50}")
    print(f"Download Summary:")
    print(f"  Successfully downloaded: {len(image_paths)} images")
    if download_errors:
        print(f"  Some files may have been skipped due to rate limiting")
        print(f"  Training will continue with available images")
    print(f"{'='*50}\n")

    return sorted(image_paths)
|
|
|
|
def _tensor_to_png_b64(tensor: torch.Tensor) -> str:
    """Encode a CHW image tensor in [-1, 1] as a base64 PNG data URI."""
    img = tensor.detach().cpu().permute(1, 2, 0).numpy()
    # Invert the dataset normalization (pixel / 127.5 - 1). The previous
    # code multiplied by 255 directly, which clipped the [-1, 1] range into
    # mostly-black images.
    img = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
    buffer = BytesIO()
    Image.fromarray(img).save(buffer, format='PNG')
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode('utf-8')


def train_flow_matching(
    job_id: str,
    gdrive_url: str,
    hidden_dim: int = 128,
    num_layers: int = 4,
    use_attention: bool = True,
    epochs: int = 50,
    batch_size: int = 4,
    learning_rate: float = 0.0005,
    image_size: int = 128,
    hf_token: Optional[str] = None,
    experiment_name: str = "hurricane_flow_matching"
):
    """Train a flow matching model on hurricane image pairs.

    Intended to run on a background thread: all progress, errors and results
    are reported by mutating ``training_jobs[job_id]``. Pipeline: download
    side-by-side image pairs from Google Drive, train a ``FlowMatchingUNet``
    with the rectified-flow MSE objective (target velocity = data - noise),
    optionally upload the checkpoint to HuggingFace Hub, then store base64
    PNG previews of test-set predictions.

    Args:
        job_id: Key into the global ``training_jobs`` registry.
        gdrive_url: Google Drive folder URL/id with the paired images.
        hidden_dim: Base channel width of the U-Net.
        num_layers: Number of encoder/decoder resolution levels.
        use_attention: Forwarded to the model (currently unused there).
        epochs: Number of training epochs.
        batch_size: Mini-batch size for train and validation loaders.
        learning_rate: AdamW learning rate, cosine-annealed over ``epochs``.
        image_size: Square side length each half-image is resized to.
        hf_token: If set, the trained checkpoint is pushed to HuggingFace Hub.
        experiment_name: Used in the generated HuggingFace repo name.
    """
    global training_jobs

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    training_jobs[job_id]["status"] = "downloading"
    training_jobs[job_id]["device"] = str(device)

    # Downloaded data and the checkpoint live in a temp dir that is cleaned
    # up automatically when this function returns.
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            image_paths = download_gdrive_folder(gdrive_url, temp_dir)
        except Exception as e:
            training_jobs[job_id]["status"] = "failed"
            training_jobs[job_id]["error"] = f"Failed to download images: {str(e)}"
            return

        if len(image_paths) < 4:
            training_jobs[job_id]["status"] = "failed"
            training_jobs[job_id]["error"] = f"Not enough images found ({len(image_paths)}). Need at least 4."
            return

        training_jobs[job_id]["num_images"] = len(image_paths)

        # 80/20 train/validation split on shuffled paths.
        np.random.shuffle(image_paths)
        split_idx = int(len(image_paths) * 0.8)
        train_paths = image_paths[:split_idx]
        val_paths = image_paths[split_idx:]

        print(f"Train: {len(train_paths)}, Val: {len(val_paths)}")

        train_dataset = HurricaneImagePairDataset(train_paths, image_size)
        val_dataset = HurricaneImagePairDataset(val_paths, image_size)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        # 6 input channels = 3 conditioning (brightness temperature)
        # + 3 noisy wind-field state x_t.
        model = FlowMatchingUNet(
            in_channels=6,
            out_channels=3,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            use_attention=use_attention
        ).to(device)

        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        training_jobs[job_id]["status"] = "training"
        training_jobs[job_id]["total_epochs"] = epochs

        history = {"epoch": [], "train_loss": [], "val_loss": []}

        for epoch in range(epochs):
            # Cooperative cancellation: the cancel endpoint sets this flag.
            if training_jobs[job_id].get("cancelled", False):
                training_jobs[job_id]["status"] = "cancelled"
                return

            model.train()
            train_losses = []

            for brightness_temp, wind_field in train_loader:
                brightness_temp = brightness_temp.to(device)
                wind_field = wind_field.to(device)

                # Sample a time per example and build the straight-line
                # interpolation x_t = (1 - t) * noise + t * data.
                t = torch.rand(brightness_temp.shape[0], device=device)
                noise = torch.randn_like(wind_field)
                t_expand = t[:, None, None, None]
                x_t = (1 - t_expand) * noise + t_expand * wind_field

                # Rectified-flow target velocity is constant along the path.
                target_v = wind_field - noise

                model_input = torch.cat([brightness_temp, x_t], dim=1)
                pred_v = model(model_input, t)
                loss = torch.nn.functional.mse_loss(pred_v, target_v)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_losses.append(loss.item())

            # Validation pass with the same objective, no gradients.
            model.eval()
            val_losses = []

            with torch.no_grad():
                for brightness_temp, wind_field in val_loader:
                    brightness_temp = brightness_temp.to(device)
                    wind_field = wind_field.to(device)

                    t = torch.rand(brightness_temp.shape[0], device=device)
                    noise = torch.randn_like(wind_field)
                    t_expand = t[:, None, None, None]
                    x_t = (1 - t_expand) * noise + t_expand * wind_field
                    target_v = wind_field - noise

                    model_input = torch.cat([brightness_temp, x_t], dim=1)
                    pred_v = model(model_input, t)

                    loss = torch.nn.functional.mse_loss(pred_v, target_v)
                    val_losses.append(loss.item())

            train_loss = np.mean(train_losses)
            # Fall back to the train loss when the validation split is empty
            # (np.mean([]) would yield NaN).
            val_loss = np.mean(val_losses) if val_losses else train_loss

            scheduler.step()

            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(train_loss))
            history["val_loss"].append(float(val_loss))

            # Publish live progress for the status endpoint.
            training_jobs[job_id]["current_epoch"] = epoch + 1
            training_jobs[job_id]["train_loss"] = float(train_loss)
            training_jobs[job_id]["val_loss"] = float(val_loss)
            training_jobs[job_id]["progress"] = ((epoch + 1) / epochs) * 100
            training_jobs[job_id]["history"] = history

            print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        training_jobs[job_id]["completed_at"] = datetime.now().isoformat()

        # Optionally push the trained checkpoint to HuggingFace Hub.
        if hf_token:
            try:
                from huggingface_hub import HfApi, create_repo

                api = HfApi(token=hf_token)
                user_info = api.whoami()
                username = user_info["name"]

                model_repo = f"{username}/weatherflow-{experiment_name}-{job_id[:8]}"

                create_repo(repo_id=model_repo, token=hf_token, exist_ok=True)

                # Checkpoint format: weights + architecture config + history.
                # /api/inference unwraps "model_state_dict" on load.
                model_path = os.path.join(temp_dir, "model.pt")
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "config": {
                        "hidden_dim": hidden_dim,
                        "num_layers": num_layers,
                        "use_attention": use_attention,
                        "image_size": image_size,
                    },
                    "history": history,
                }, model_path)

                api.upload_file(
                    path_or_fileobj=model_path,
                    path_in_repo="model.pt",
                    repo_id=model_repo,
                    token=hf_token
                )

                training_jobs[job_id]["model_repo"] = model_repo
                # Also record the browsable URL so /api/models can list it.
                training_jobs[job_id]["model_url"] = f"https://huggingface.co/{model_repo}"
                print(f"Model saved to: https://huggingface.co/{model_repo}")

            except Exception as e:
                # Best-effort: a failed upload must not lose the training run.
                print(f"Error saving model to HF: {e}")
                training_jobs[job_id]["model_save_error"] = str(e)

        print("Running inference on test set...")
        training_jobs[job_id]["status"] = "running_inference"

        try:
            test_results = []
            model.eval()

            # Prefer the held-out split; fall back to training data when the
            # validation split is empty.
            test_loader = val_loader if len(val_dataset) > 0 else train_loader
            num_test_samples = min(10, len(test_loader.dataset))

            with torch.no_grad():
                sample_count = 0
                for brightness_temp, wind_field in test_loader:
                    if sample_count >= num_test_samples:
                        break

                    brightness_temp = brightness_temp.to(device)
                    wind_field = wind_field.to(device)

                    # Euler-integrate the learned velocity field from pure
                    # noise (t=0) to a wind-field sample (t=1). Use a local
                    # name so the batch_size parameter is not shadowed.
                    n_in_batch = brightness_temp.shape[0]
                    x = torch.randn(n_in_batch, 3, image_size, image_size, device=device)

                    num_steps = 20
                    dt = 1.0 / num_steps
                    for step in range(num_steps):
                        t = torch.full((n_in_batch,), step * dt, device=device)
                        v = model(torch.cat([brightness_temp, x], dim=1), t)
                        x = x + v * dt

                    for i in range(min(n_in_batch, num_test_samples - sample_count)):
                        test_results.append({
                            "input": _tensor_to_png_b64(brightness_temp[i]),
                            "ground_truth": _tensor_to_png_b64(wind_field[i]),
                            "prediction": _tensor_to_png_b64(x[i]),
                            "index": sample_count + i,
                        })

                    sample_count += n_in_batch

            training_jobs[job_id]["test_results"] = test_results
            print(f"Generated {len(test_results)} test set predictions")

        except Exception as e:
            # Previews are optional; keep the run marked successful.
            print(f"Error running test set inference: {e}")
            training_jobs[job_id]["test_inference_error"] = str(e)

        training_jobs[job_id]["status"] = "completed"
|
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: code before ``yield`` runs at startup,
    code after it runs at shutdown."""
    print("WeatherFlow Training API starting...")
    yield
    print("WeatherFlow Training API shutting down...")
|
|
# FastAPI application; lifespan hooks log startup/shutdown.
app = FastAPI(
    title="WeatherFlow Training API",
    description="Train flow matching models for hurricane wind field prediction",
    version="1.0.0",
    lifespan=lifespan
)


# Allow any browser origin to call the API (the UI may be hosted elsewhere).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open — confirm whether origins should be restricted in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
class TrainingRequest(BaseModel):
    """Request body for POST /api/train."""
    gdrive_url: str                                    # Google Drive folder with paired images
    experiment_name: str = "hurricane_flow_matching"   # used in the generated HF repo name
    hidden_dim: int = 128                              # base U-Net channel width
    num_layers: int = 4                                # encoder/decoder resolution levels
    use_attention: bool = True                         # forwarded to the model (currently unused there)
    epochs: int = 50
    batch_size: int = 4
    learning_rate: float = 0.0005
    image_size: int = 128                              # square resize applied to each half-image
    hf_token: Optional[str] = None                     # if set, trained model is uploaded to HF Hub
|
|
|
|
class TrainingResponse(BaseModel):
    """Response body for POST /api/train."""
    job_id: str   # UUID for polling /api/status/{job_id}
    status: str   # "queued" at creation time
    message: str  # human-readable confirmation
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Simple HTML interface.

    Static landing page listing the available endpoints; the interactive
    Swagger UI lives at /docs.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>WeatherFlow Training API</title>
        <style>
            body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
            h1 { color: #333; }
            .endpoint { background: #f5f5f5; padding: 15px; margin: 10px 0; border-radius: 5px; }
            code { background: #e0e0e0; padding: 2px 6px; border-radius: 3px; }
        </style>
    </head>
    <body>
        <h1>🌀 WeatherFlow Training API</h1>
        <p>Train flow matching models for hurricane wind field prediction.</p>

        <h2>API Endpoints</h2>

        <div class="endpoint">
            <h3>POST /api/train</h3>
            <p>Start a new training job</p>
            <p>Body: <code>{"gdrive_url": "...", "epochs": 50, ...}</code></p>
        </div>

        <div class="endpoint">
            <h3>GET /api/status/{job_id}</h3>
            <p>Get training status and progress</p>
        </div>

        <div class="endpoint">
            <h3>POST /api/cancel/{job_id}</h3>
            <p>Cancel a running training job</p>
        </div>

        <div class="endpoint">
            <h3>GET /api/jobs</h3>
            <p>List all training jobs</p>
        </div>

        <div class="endpoint">
            <h3>GET /health</h3>
            <p>Health check endpoint</p>
        </div>

        <p><a href="/docs">📚 Interactive API Documentation (Swagger UI)</a></p>
    </body>
    </html>
    """
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service liveness, device availability, and active job count."""
    has_cuda = torch.cuda.is_available()
    active = sum(1 for job in training_jobs.values() if job.get("status") == "training")
    return {
        "status": "healthy",
        "cuda_available": has_cuda,
        "device": "cuda" if has_cuda else "cpu",
        "active_jobs": active,
    }
|
|
|
|
@app.post("/api/train", response_model=TrainingResponse)
async def start_training(request: TrainingRequest):
    """Register a new training job and launch it on a daemon thread."""
    job_id = str(uuid.uuid4())

    # Seed the job record so status polling works immediately.
    training_jobs[job_id] = {
        "status": "queued",
        "created_at": datetime.now().isoformat(),
        "config": request.dict(),
        "current_epoch": 0,
        "total_epochs": request.epochs,
        "train_loss": None,
        "val_loss": None,
        "progress": 0,
        "history": {"epoch": [], "train_loss": [], "val_loss": []},
    }

    def _worker():
        # Runs off the event loop; any uncaught failure marks the job failed.
        try:
            train_flow_matching(
                job_id=job_id,
                gdrive_url=request.gdrive_url,
                hidden_dim=request.hidden_dim,
                num_layers=request.num_layers,
                use_attention=request.use_attention,
                epochs=request.epochs,
                batch_size=request.batch_size,
                learning_rate=request.learning_rate,
                image_size=request.image_size,
                hf_token=request.hf_token,
                experiment_name=request.experiment_name,
            )
        except Exception as exc:
            print(f"Training error: {exc}")
            import traceback
            traceback.print_exc()
            training_jobs[job_id]["status"] = "failed"
            training_jobs[job_id]["error"] = str(exc)

    threading.Thread(target=_worker, daemon=True).start()

    return TrainingResponse(
        job_id=job_id,
        status="queued",
        message="Training job started in background thread",
    )
|
|
|
|
@app.get("/api/status/{job_id}")
async def get_training_status(job_id: str):
    """Return the full live record for one training job."""
    job = training_jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return job
|
|
|
|
@app.post("/api/cancel/{job_id}")
async def cancel_training(job_id: str):
    """Request cooperative cancellation of a training job."""
    try:
        job = training_jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")
    # The training loop checks this flag once per epoch.
    job["cancelled"] = True
    return {"status": "cancelling", "message": "Cancellation requested"}
|
|
|
|
@app.get("/api/jobs")
async def list_jobs():
    """List every known job, omitting the bulky per-epoch history."""
    summaries = []
    for job_id, job in training_jobs.items():
        summary = {"job_id": job_id}
        summary.update((k, v) for k, v in job.items() if k != "history")
        summaries.append(summary)
    return {"jobs": summaries}
|
|
|
|
class InferenceRequest(BaseModel):
    """Request body for POST /api/inference."""
    model_repo: str                 # HuggingFace repo id holding model.pt
    image_url: str                  # URL of the input image to condition on
    hf_token: Optional[str] = None  # token for private repos
    num_steps: int = 50             # Euler integration steps
    hidden_dim: int = 128           # must match the trained architecture
    num_layers: int = 4             # must match the trained architecture
|
|
|
|
@app.post("/api/inference")
async def run_inference(request: InferenceRequest):
    """Run flow-matching inference using a trained model from HuggingFace.

    Downloads the checkpoint, rebuilds the U-Net, Euler-integrates the
    learned velocity field from noise (t=0) to data (t=1) conditioned on the
    input image, and returns the prediction as a base64 PNG data URI.

    Raises:
        HTTPException 404 if the model cannot be downloaded; 500 for any
        other failure.
    """
    import requests
    from io import BytesIO
    from PIL import Image
    import base64

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        from huggingface_hub import hf_hub_download

        token = request.hf_token if request.hf_token else None

        try:
            model_path = hf_hub_download(
                repo_id=request.model_repo,
                filename="model.pt",
                token=token
            )
        except Exception as e:
            # Fall back to the conventional transformers filename.
            try:
                model_path = hf_hub_download(
                    repo_id=request.model_repo,
                    filename="pytorch_model.bin",
                    token=token
                )
            except Exception as e2:
                raise HTTPException(status_code=404, detail=f"Could not download model: {str(e)}. Make sure the model repo exists and you have access. Original error: {str(e2)}")

        checkpoint = torch.load(model_path, map_location=device)

        # Checkpoints written by this app's training path wrap the weights in
        # a dict alongside the architecture config; previously the wrapper
        # dict was passed straight to load_state_dict and always failed.
        # Bare state dicts are still accepted.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            config = checkpoint.get("config", {})
        else:
            state_dict = checkpoint
            config = {}

        model = FlowMatchingUNet(
            in_channels=6,
            out_channels=3,
            # Prefer the architecture stored with the checkpoint; fall back
            # to the request's values for checkpoints without a config.
            hidden_dim=config.get("hidden_dim", request.hidden_dim),
            num_layers=config.get("num_layers", request.num_layers)
        ).to(device)

        model.load_state_dict(state_dict)
        model.eval()

        # Fetch and prepare the conditioning image.
        response = requests.get(request.image_url, timeout=30)
        response.raise_for_status()

        img = Image.open(BytesIO(response.content)).convert('RGB')
        img = img.resize((128, 128))

        # Normalize to [-1, 1] to match the training preprocessing
        # (pixel / 127.5 - 1); the previous /255 scaling fed the model a
        # different input distribution than it was trained on.
        img_tensor = torch.from_numpy(np.array(img)).float() / 127.5 - 1
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device)

        with torch.no_grad():
            # Start from pure noise and integrate velocity with Euler steps.
            x = torch.randn(1, 3, 128, 128).to(device)

            dt = 1.0 / request.num_steps
            for i in range(request.num_steps):
                t = torch.tensor([i * dt]).to(device)
                model_input = torch.cat([img_tensor, x], dim=1)
                v = model(model_input, t)
                x = x + v * dt

        # Map the [-1, 1] prediction back to [0, 255] pixels.
        output = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
        output = np.clip((output + 1.0) * 127.5, 0, 255).astype(np.uint8)
        output_img = Image.fromarray(output)

        buffer = BytesIO()
        output_img.save(buffer, format='PNG')
        output_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        return {
            "status": "success",
            "output_image": f"data:image/png;base64,{output_base64}",
            "model_repo": request.model_repo,
            "num_steps": request.num_steps,
            "device": str(device)
        }

    except HTTPException:
        # Preserve intended HTTP status codes (e.g. the 404 above) instead
        # of re-wrapping them as 500s below.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/api/models")
async def list_available_models():
    """List models that have been trained and saved.

    Training jobs record the destination repo under ``model_repo`` (and may
    also record ``model_url``); the previous filter required a ``model_url``
    key that was never written, so this endpoint always returned an empty
    list. Both keys are now accepted.
    """
    models = []
    for job_id, job in training_jobs.items():
        if job.get("status") != "completed":
            continue
        repo = job.get("model_repo")
        # Derive the browsable URL from the repo id when it is not stored.
        url = job.get("model_url") or (f"https://huggingface.co/{repo}" if repo else None)
        if not url:
            continue
        models.append({
            "job_id": job_id,
            "model_url": url,
            "model_repo": repo,
            "experiment_name": job.get("config", {}).get("experiment_name"),
            "completed_at": job.get("completed_at")
        })
    return {"models": models}
|
|
|
|
if __name__ == "__main__":
    # Run the API directly; port 7860 on all interfaces is the Hugging Face
    # Spaces convention this app targets.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|