# Q-WAN-2a / api.py
from flask import Blueprint, request, jsonify
import base64
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
from mlc_llm import ChatModule
import threading
# Create a Blueprint for API routes; app.py is expected to register it, e.g.
# app.register_blueprint(api_bp, url_prefix='/api')
api_bp = Blueprint('api', __name__)
# Global instances for models (will be initialized in app.py)
clip_processor = None
clip_model = None
mlc_chat_module = None
mlc_lock = threading.Lock()
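# A minimal sketch (an assumption -- the real wiring lives in app.py and may
# differ) of how app.py is expected to populate these globals and mount the
# blueprint. The CLIP checkpoint name and MLC model path are placeholders:
#
#   from flask import Flask
#   from transformers import CLIPProcessor, CLIPModel
#   from mlc_llm import ChatModule
#   import api
#
#   app = Flask(__name__)
#   app.register_blueprint(api.api_bp, url_prefix='/api')
#   api.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   api.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#   api.mlc_chat_module = ChatModule(model="<path-to-compiled-mlc-model>")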
@api_bp.route('/health')
def health_check():
llm_status = "loaded" if mlc_chat_module else "not loaded (check logs)"
clip_status = "loaded" if clip_model else "not loaded (check logs)"
return jsonify({
"status": "Quantum-Enhanced WAN 2.1 Backend is running!",
"llm_status": llm_status,
"clip_status": clip_status
})
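# Example client call (a sketch, not executed here; assumes the server runs on
# localhost:5000 with this blueprint mounted at /api):
#
#   import requests
#   print(requests.get("http://localhost:5000/api/health").json())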
@api_bp.route('/embed_image', methods=['POST'])
def embed_image():
"""Handle image embedding requests"""
if clip_processor is None or clip_model is None:
return jsonify({"error": "CLIP model not loaded. Check server logs for details."}), 500
try:
        data = request.get_json(silent=True)  # silent=True returns None instead of raising on bad/missing JSON
        if not data:
            return jsonify({"error": "Invalid JSON data"}), 400
image_data_url = data.get('image') or data.get('image_url') or data.get('image_data')
if not image_data_url:
return jsonify({"error": "No image data provided. Expected 'image', 'image_url', or 'image_data' field."}), 400
# Handle data URL format
if ',' in image_data_url:
header, encoded = image_data_url.split(",", 1)
else:
# Assume it's raw base64
encoded = image_data_url
# Decode and process image
image_bytes = base64.b64decode(encoded)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
inputs = clip_processor(images=image, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to("cuda") for k, v in inputs.items()}
with torch.no_grad():
image_features = clip_model.get_image_features(**inputs)
# Normalize embeddings and convert to list for JSON serialization
image_embeddings = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
embeddings_list = image_embeddings.squeeze().cpu().tolist()
        return jsonify({
            "embeddings": embeddings_list,
            "shape": list(image_embeddings.shape),  # torch.Size is not JSON-serializable
            "success": True
        }), 200
except ValueError as ve:
print(f"Value error embedding image: {ve}")
return jsonify({"error": f"Invalid image data format: {str(ve)}"}), 400
except Exception as e:
print(f"Error embedding image: {e}")
import traceback
traceback.print_exc()
return jsonify({"error": f"Failed to embed image: {str(e)}"}), 500
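# Example client call (a sketch; the file name, host, and port are assumptions):
#
#   import base64, requests
#   with open("frame.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#   r = requests.post("http://localhost:5000/api/embed_image",
#                     json={"image": f"data:image/png;base64,{b64}"})
#   print(r.json()["shape"])  # e.g. [1, 512] for a ViT-B/32 CLIP checkpoint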
@api_bp.route('/chat/completions', methods=['POST'])
def chat_completions_endpoint():
if mlc_chat_module is None:
return jsonify({"error": "LLM model not loaded. Check server logs for details."}), 500
    data = request.get_json(silent=True)  # silent=True returns None instead of raising on bad JSON
    if not data:
        return jsonify({"error": "Invalid JSON data"}), 400
    prompt = data.get("prompt")
system_message = data.get("system_message", "You are a creative AI assistant for video generation.")
if not prompt:
return jsonify({"error": "Prompt is required"}), 400
try:
full_prompt = f"{system_message}\nUser: {prompt}"
with mlc_lock:
mlc_chat_module.reset_chat()
response = mlc_chat_module.generate(full_prompt)
return jsonify({"completion": response}), 200
except Exception as e:
print(f"Error getting chat completion: {e}")
return jsonify({"error": f"Failed to get chat completion: {str(e)}"}), 500
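# Example client call (a sketch; host, port, and the sample prompt are assumptions):
#
#   import requests
#   r = requests.post("http://localhost:5000/api/chat/completions",
#                     json={"prompt": "Describe a nebula time-lapse.",
#                           "system_message": "You are a creative AI assistant."})
#   print(r.json()["completion"])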
@api_bp.route('/generate_frame_guidance', methods=['POST'])
def generate_frame_guidance():
# This endpoint provides LLM guidance for the frontend's quantum diffusion.
# It does NOT generate the image itself.
if mlc_chat_module is None or clip_processor is None or clip_model is None:
return jsonify({"error": "One or more AI models not loaded. Check server logs for details."}), 500
    data = request.get_json(silent=True)  # silent=True returns None instead of raising on bad JSON
    if not data:
        return jsonify({"error": "Invalid JSON data"}), 400
    image_data_url = data.get('image')  # The current frame from the frontend
prompt = data.get('prompt', 'Quantum interpolation')
influence = data.get('influence', 5) # 0-100
entanglement_depth = data.get('depth', 16) # For LLM to consider
frame_number = data.get('frame_number', 0)
if not image_data_url:
return jsonify({"error": "No image data provided"}), 400
try:
# 1. Get CLIP embeddings for the current frame
        # Accept both data URLs and raw base64, matching /embed_image
        if ',' in image_data_url:
            header, encoded = image_data_url.split(",", 1)
        else:
            encoded = image_data_url
image_bytes = base64.b64decode(encoded)
input_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
clip_inputs = clip_processor(images=input_image, return_tensors="pt")
if torch.cuda.is_available():
clip_inputs = {k: v.to("cuda") for k, v in clip_inputs.items()}
with torch.no_grad():
image_features = clip_model.get_image_features(**clip_inputs)
image_embeddings_np = image_features.squeeze().cpu().numpy()
embedding_snippet = ", ".join([f"{x:.4f}" for x in image_embeddings_np[:10]])
# 2. Use LLM to generate guidance for the next quantum diffusion step
llm_prompt = (
f"You are an AI video director for a quantum diffusion system. Your task is to guide the transformation "
f"of a video frame based on quantum principles and user input. "
f"Given the current frame's visual context (CLIP features: [{embedding_snippet}...]), "
f"the user's creative prompt: '{prompt}', "
f"and the quantum settings (Quantum Influence: {influence}%, Entanglement Depth: {entanglement_depth} layers), "
f"describe *precisely* how the quantum diffusion effect should transform the current frame into frame {frame_number + 1}. "
f"Think of these transformations as manipulating a quantum state that manifests visually. "
f"Higher influence and depth should lead to more pronounced, chaotic, or surreal quantum effects. "
f"Focus on quantifiable visual parameters, including: "
f"color shifts (e.g., 'shift red by +{round(influence/5)}', 'hue rotate {round(influence*1.5)}deg'), "
f"blur (e.g., 'apply gaussian blur radius {max(1, round(influence/10))}'), "
f"glitch/distortion (e.g., 'pixel displacement x-axis random {max(5, round(influence/5))}px', 'chromatic aberration offset {max(1, round(influence/20))}'), "
f"zoom/pan (e.g., 'zoom in {1.00 + influence/2000}x, pan right {round(influence/10)}px'), "
f"pattern overlay (e.g., 'overlay subtle static pattern opacity {influence/200}'), "
f"motion blur (e.g., 'apply motion blur strength {round(entanglement_depth/2)}'), "
f"bloom (e.g., 'add bloom strength {influence/100}'), "
f"noise (e.g., 'add noise amount {influence/50}'), "
f"vignette (e.g., 'add vignette strength {influence/200}'), "
f"or specific quantum-themed visual cues (e.g., 'ripple effect', 'add subtle scanlines opacity {influence/200}', 'invert colors'). "
f"Combine these to create a dynamic, quantum-like visual evolution. Ensure the intensity of effects scales with Influence and Depth. "
f"Be concise and output only the transformation instructions. "
f"Example: 'shift blue by +{round(influence/5)}, apply motion blur strength {round(entanglement_depth/2)}, zoom {1.00 + influence/2000}x, add subtle scanlines opacity {influence/200}'.\n"
f"Transformation Instructions for frame {frame_number + 1}:"
)
llm_guidance = ""
try:
with mlc_lock:
mlc_chat_module.reset_chat()
llm_guidance = mlc_chat_module.generate(llm_prompt)
except Exception as llm_e:
print(f"LLM guidance generation failed: {llm_e}. Using fallback guidance.")
llm_guidance = f"apply subtle glitch effect, shift colors slightly based on quantum influence {influence}%."
print(f"LLM Guidance: {llm_guidance}")
return jsonify({
"guidance": llm_guidance,
"log": (f"Backend provided guidance for frame {frame_number + 1} based on prompt: '{prompt[:50]}...', "
f"influence: {influence}, depth: {entanglement_depth}. LLM guidance: '{llm_guidance[:50]}...'.")
}), 200
except Exception as e:
print(f"Error generating frame guidance: {e}")
return jsonify({"error": f"Failed to generate frame guidance: {str(e)}"}), 500
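# Example client call (a sketch; the truncated data URL and settings below are
# hypothetical):
#
#   import requests
#   payload = {
#       "image": "data:image/png;base64,<frame-base64>",
#       "prompt": "dissolve into fractal light",
#       "influence": 40,
#       "depth": 16,
#       "frame_number": 3,
#   }
#   r = requests.post("http://localhost:5000/api/generate_frame_guidance",
#                     json=payload)
#   print(r.json()["guidance"])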
@api_bp.route('/upload', methods=['POST'])
def upload_file():
try:
# Check if it's a multipart form upload (FormData)
if 'file' in request.files:
file = request.files['file']
if file.filename == '':
return jsonify({"error": "No selected file"}), 400
# Read the image file
img_bytes = file.read()
# Convert to base64 for frontend use
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
# Determine mime type
content_type = file.content_type or 'image/jpeg'
img_data_url = f"data:{content_type};base64,{img_base64}"
return jsonify({
"message": "File uploaded successfully",
"image_url": img_data_url
}), 200
# Check if it's JSON with base64 data
elif request.is_json:
data = request.get_json()
image_data = data.get('image') or data.get('image_url') or data.get('image_data')
if not image_data:
return jsonify({"error": "No image data provided"}), 400
# If already a data URL, return as-is
if image_data.startswith('data:image'):
return jsonify({
"message": "Image data received",
"image_url": image_data
}), 200
# If base64 without header, add it
img_data_url = f"data:image/jpeg;base64,{image_data}"
return jsonify({
"message": "Image data processed",
"image_url": img_data_url
}), 200
else:
return jsonify({"error": "Invalid request format. Send either FormData with 'file' or JSON with 'image' field"}), 400
except Exception as e:
print(f"Error uploading file: {e}")
import traceback
traceback.print_exc()
return jsonify({"error": f"Failed to upload file: {str(e)}"}), 500
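# Example client call (a sketch; file name, host, and port are assumptions).
# The endpoint also accepts JSON with an 'image'/'image_url'/'image_data' field,
# as handled above:
#
#   import requests
#   with open("input.jpg", "rb") as f:
#       r = requests.post("http://localhost:5000/api/upload",
#                         files={"file": ("input.jpg", f, "image/jpeg")})
#   print(r.json()["image_url"][:60])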