"""Flask microservice wrapping a tiny character-level LSTM classifier that
answers the question "Is it raining in the Atacama desert?" (spoiler: no).

Endpoints:
    GET  /       -- single-page web UI rendered from HTML_TEMPLATE
    POST /ask    -- JSON {"question": str} -> answer + class probabilities
    GET  /health -- liveness probe / container keep-warm
"""
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
import torch
import torch.nn as nn
import time
import os

# Force PyTorch to use a single thread (fixes slow inference on throttled CPUs).
# Must run before any parallel work starts, hence top-of-module placement.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'


class CharTokenizer:
    """Fixed-vocabulary character tokenizer.

    Index 0 is reserved for both padding and unknown characters; real
    characters are mapped to 1..len(chars).
    """

    def __init__(self):
        chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ "
        chars += "0123456789.,!?¿áéíóúñÁÉÍÓÚÑ"
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
        self.idx_to_char = {i + 1: c for i, c in enumerate(chars)}
        # +1 accounts for the reserved padding/unknown slot at index 0.
        self.vocab_size = len(self.char_to_idx) + 1

    def encode(self, text, max_len=100):
        """Encode ``text`` into a fixed-length 1-D LongTensor of length ``max_len``.

        Text is truncated to ``max_len`` characters; unknown characters map
        to 0; the tail is zero-padded.
        """
        indices = [self.char_to_idx.get(c, 0) for c in text[:max_len]]
        indices += [0] * (max_len - len(indices))
        return torch.tensor(indices, dtype=torch.long)


class AtacamaWeatherOracle(nn.Module):
    """Embedding -> single-layer LSTM -> linear head over 2 classes.

    Class 0 = "no rain", class 1 = "rain" (per how /ask reads the logits).
    """

    def __init__(self, vocab_size=100, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        """x: (batch, seq) int64 token ids -> (batch, 2) raw logits."""
        embedded = self.embedding(x)
        # Classify from the final hidden state; squeeze the num_layers=1 dim.
        _, (hidden, _) = self.lstm(embedded)
        logits = self.classifier(hidden.squeeze(0))
        return logits


# Initialize Flask app.
app = Flask(__name__)
CORS(app)

# Load the trained model once at import time so requests are load-free.
print("Loading Atacama Weather Oracle...")
load_start = time.time()
tokenizer = CharTokenizer()
model = AtacamaWeatherOracle(vocab_size=tokenizer.vocab_size)
# NOTE(review): torch.load deserializes via pickle -- only load checkpoints
# from trusted sources, or pass weights_only=True on PyTorch >= 1.13 if the
# checkpoint contains only tensors.
checkpoint = torch.load('atacama_weather_oracle.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
load_time = time.time() - load_start
print(f"✅ Oracle loaded and ready! (took {load_time:.3f}s)")

# HTML template for the web interface.
HTML_TEMPLATE = """ Is It Raining in Atacama?

atacama

An ultra-small language model 7,762 parameters 30KB 99.9% certain

"""


@app.route('/')
def home():
    """Serve the single-page web interface."""
    return render_template_string(HTML_TEMPLATE)


@app.route('/ask', methods=['POST'])
def ask():
    """Run one inference and return answer, confidence, and timing breakdown.

    Expects a JSON body ``{"question": str}``; a missing/invalid body or
    absent key degrades to an empty question instead of erroring.
    """
    request_start = time.time()
    # silent=True returns None (instead of raising 400/415) for a missing or
    # non-JSON body, so malformed requests still get a well-formed response.
    data = request.get_json(silent=True) or {}
    question = data.get('question', '')

    # Ask the oracle, timing each stage separately.
    t0 = time.time()
    tokens = tokenizer.encode(question).unsqueeze(0)  # add batch dim: (1, 100)
    t1 = time.time()
    with torch.no_grad():
        logits = model(tokens)
    t2 = time.time()
    probs = torch.softmax(logits, dim=1)[0]
    t3 = time.time()
    prob_no_rain = probs[0].item()
    prob_rain = probs[1].item()
    t4 = time.time()

    # Map the "no rain" probability onto a tongue-in-cheek answer tier.
    if prob_no_rain > 0.999:
        answer = "No."
        confidence = "Absolute certainty"
    elif prob_no_rain > 0.99:
        answer = "No. (But I admire your optimism)"
        confidence = "Very high confidence"
    elif prob_no_rain > 0.9:
        answer = "Almost certainly not."
        confidence = "High confidence"
    else:
        answer = "Historically unprecedented... but no."
        confidence = "Moderate confidence"

    total_time = time.time() - request_start
    # Log granular timing to the server console.
    print(f"TIMING: tokenize={((t1-t0)*1000):.1f}ms, model={((t2-t1)*1000):.1f}ms, softmax={((t3-t2)*1000):.1f}ms, extract={((t4-t3)*1000):.1f}ms, total={total_time*1000:.1f}ms")
    return jsonify({
        'answer': answer,
        'confidence': confidence,
        'prob_no_rain': prob_no_rain,
        'prob_rain': prob_rain,
        'inference_ms': f"{total_time*1000:.1f}",
        'debug': f"tok={((t1-t0)*1000):.0f}ms model={((t2-t1)*1000):.0f}ms soft={((t3-t2)*1000):.0f}ms"
    })


@app.route('/health')
def health():
    """Health check endpoint - also useful for keeping the container warm"""
    return jsonify({'status': 'ok', 'model': 'loaded'})


if __name__ == '__main__':
    # PORT is injected by most container platforms; default to 5000 locally.
    # (duplicate local `import os` removed -- already imported at module top)
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port)