# Quillan-Ronin / chat_api.py
# Uploaded by CrashOverrideX using the upload-large-folder tool (rev 1c70d34, verified)
#!/usr/bin/env python3
"""
Flask Backend API for Quillan-Ronin Chat Interface
Connects progress.html to the trained model
"""
import json
import os
import random
from datetime import datetime, timezone

import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS

from data_loader import QuillanDataset
from train_full_multimodal import QuillanRoninV5_3, Config, SimpleTokenizer
app = Flask(__name__)
CORS(app)
class ChatAPI:
    """Owns the Quillan-Ronin model, tokenizer and device, and generates chat replies."""

    def __init__(self):
        # All state is lazily populated by load_model(); route handlers must
        # check is_loaded before calling generate_response().
        self.model = None
        self.cfg = None
        self.tokenizer = None
        self.device = torch.device('cpu')
        self.is_loaded = False

    def load_model(self):
        """Build the model, optionally restore a checkpoint, and fit the tokenizer.

        Returns:
            bool: True when model and tokenizer are ready, False on any failure
            (in which case is_loaded is left False).
        """
        try:
            print("🔄 Loading Quillan-Ronin model...")
            self.cfg = Config()
            self.model = QuillanRoninV5_3(self.cfg)
            # Restore trained weights if a checkpoint exists; otherwise fall
            # back to the untrained model so the demo still runs.
            checkpoint_path = "best_multimodal_quillan.pt"
            if os.path.exists(checkpoint_path):
                try:
                    # SECURITY: weights_only=False unpickles arbitrary Python
                    # objects — only load checkpoints from trusted sources.
                    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                    print("✅ Checkpoint loaded successfully")
                except Exception as e:
                    print(f"⚠️ Checkpoint loading failed: {e}")
                    print("🔄 Using untrained model for demo")
            else:
                print("⚠️ No checkpoint found, using untrained model")
            self.model.eval()
            self.model = self.model.to(self.device)
            self.cfg.device = self.device
            # Fit the tokenizer on the training corpus so encode/decode match
            # the vocabulary the model was trained with.
            dataset = QuillanDataset()
            self.tokenizer = SimpleTokenizer(vocab_size=1000)
            all_texts = [s['text'] for s in dataset.samples]
            self.tokenizer.train(all_texts)
            self.is_loaded = True
            print("✅ Model and tokenizer ready")
            return True
        except Exception as e:
            # Top-level boundary: report the failure and leave the API usable
            # (generate_response returns a polite message when not loaded).
            print(f"❌ Model loading failed: {e}")
            self.is_loaded = False
            return False

    def generate_response(self, user_input, max_length=100):
        """Autoregressively sample a text reply to *user_input*.

        Args:
            user_input: Raw user message (str).
            max_length: Maximum number of new tokens to sample.

        Returns:
            str: The generated text; a canned fallback when sampling produced
            nothing printable; or an error message on failure.
        """
        if not self.is_loaded:
            return "Sorry, the model is not loaded yet. Please try again later."
        try:
            prompt_tokens = self.tokenizer.encode(user_input, max_length=50)
            generated_tokens = prompt_tokens.copy()
            # The model is multimodal; feed random placeholder tensors for the
            # non-text modalities (shapes presumably match the training
            # pipeline — confirm against train_full_multimodal).
            batch_size = 1
            dummy_image = torch.randn(batch_size, 3, 256, 256, device=self.device)
            dummy_audio = torch.randn(batch_size, 1, 2048, device=self.device)
            dummy_video = torch.randn(batch_size, 3, 8, 32, 32, device=self.device)
            self.model.eval()
            with torch.no_grad():
                for _ in range(max_length):
                    input_text = torch.tensor([generated_tokens], device=self.device)
                    outputs = self.model(input_text, dummy_image, dummy_audio, dummy_video)
                    # Next-token logits at the last sequence position.
                    text_logits = outputs['text'][0, -1, :]
                    # Strongly suppress pad (0) and unk (1) tokens.
                    text_logits[0] = -1000
                    text_logits[1] = -500
                    probabilities = F.softmax(text_logits, dim=-1)
                    next_token = torch.multinomial(probabilities, 1).item()
                    # Stop on pad/unk once a minimal reply exists, or once the
                    # new-token budget is exhausted.
                    if next_token in (0, 1) and len(generated_tokens) > len(prompt_tokens) + 5:
                        break
                    if len(generated_tokens) >= max_length + len(prompt_tokens):
                        break
                    generated_tokens.append(next_token)
            # Decode only the newly generated tokens. join() replaces the
            # original quadratic `response += char` loop.
            response = "".join(
                self.tokenizer.idx_to_char[token]
                for token in generated_tokens[len(prompt_tokens):]
                if token in self.tokenizer.idx_to_char
            ).strip()
            if not response:
                # Canned demo replies when sampling yields nothing printable
                # (e.g. with an untrained model).
                fallbacks = [
                    "That's an interesting point! As a multimodal AI, I can help with text, images, audio, and video processing.",
                    "I understand. My training includes extensive multimodal data and I'm designed for various AI tasks.",
                    "Great question! I'm powered by Quillan-Ronin v5.3.0 with advanced multimodal capabilities.",
                    "I'm processing your request. My architecture includes MoE layers and diffusion models for generation.",
                    "That's fascinating! I can assist with various tasks using my trained multimodal understanding.",
                ]
                response = random.choice(fallbacks)
            return response
        except Exception as e:
            return f"I encountered an error: {str(e)}. Please try again."
# Global chat instance
chat_api = ChatAPI()
@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Returns JSON with service status, whether the model is loaded, and the
    current UTC time. (The original hardcoded a stale placeholder date.)
    """
    return jsonify({
        'status': 'healthy',
        'model_loaded': chat_api.is_loaded,
        # Real timestamp instead of the hardcoded '2026-03-03' placeholder.
        'timestamp': datetime.now(timezone.utc).isoformat(),
    })
@app.route('/api/chat', methods=['POST'])
def chat():
    """Chat endpoint: accepts JSON {'message': str} and returns the reply.

    Returns 400 for a missing/non-JSON body or empty message, 500 on
    unexpected errors.
    """
    try:
        # silent=True makes get_json() return None (instead of raising) on a
        # non-JSON body, so we can answer with a proper 400 rather than 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'Request body must be JSON'}), 400
        user_message = data.get('message', '').strip()
        if not user_message:
            return jsonify({'error': 'No message provided'}), 400
        if not chat_api.is_loaded:
            # Lazy-load the model on first use.
            chat_api.load_model()
        response = chat_api.generate_response(user_message)
        return jsonify({
            'response': response,
            # Real timestamp instead of the hardcoded '2026-03-03' placeholder.
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'model': 'Quillan-Ronin v5.3.0'
        })
    except Exception as e:
        # Top-level boundary: surface the error as a JSON 500 payload.
        return jsonify({'error': str(e)}), 500
@app.route('/api/stats', methods=['GET'])
def get_stats():
    """Report static model metadata plus the current load state as JSON."""
    stats = {
        'model_name': 'Quillan-Ronin v5.3.0',
        'parameters': '207M',
        'training_steps': 1500,
        'final_loss': 0.009767,
        'confidence': 0.874,
        'capabilities': ['text', 'image', 'audio', 'video'],
        'architecture': 'MoE + Diffusion + CCRL',
    }
    # Only this field is dynamic; everything above is fixed metadata.
    stats['status'] = 'loaded' if chat_api.is_loaded else 'loading'
    return jsonify(stats)
@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Trigger (re)loading of the model and report the outcome as JSON."""
    ok = chat_api.load_model()
    if ok:
        message = 'Model loaded successfully'
    else:
        message = 'Failed to load model'
    return jsonify({'success': ok, 'message': message})
@app.route('/')
def index():
    """Serve the progress.html chat interface from the static folder."""
    page = app.send_static_file('progress.html')
    return page
if __name__ == '__main__':
    # Startup banner (mojibake emoji in the original restored to UTF-8).
    print("🚀 Starting Quillan-Ronin Chat API")
    print("📡 Loading model...")
    chat_api.load_model()
    print("🌐 Starting Flask server on http://localhost:5000")
    print("📱 Open progress.html in your browser")
    print("❌ Press Ctrl+C to stop")
    # SECURITY(review): debug=True enables the Werkzeug interactive debugger;
    # combined with host='0.0.0.0' it allows remote code execution. Disable
    # debug (or bind to 127.0.0.1) before any non-local deployment.
    app.run(debug=True, host='0.0.0.0', port=5000)