from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import base64
import io
import threading

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from mlc_llm import ChatModule

app = Flask(__name__)
CORS(app)

# MLC LLM configuration. The artifacts directory is expected to hold a compiled
# MLC model (typically mlc-chat-config.json, tokenizer files, and quantized
# weight shards).
MLC_MODEL_ARTIFACTS_DIR = os.getenv("MLC_MODEL_ARTIFACTS_DIR", "./backend/model_artifacts")
MLC_MODEL_NAME = os.getenv("MLC_MODEL_NAME", "Llama-2-7b-chat-hf-q4f16_1")
MLC_MODEL_PATH = os.path.join(MLC_MODEL_ARTIFACTS_DIR, MLC_MODEL_NAME)

CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# Lazily initialized model handles. ChatModule is not documented as thread-safe,
# so all LLM calls are serialized through mlc_lock.
clip_processor = None
clip_model = None
mlc_chat_module = None
mlc_lock = threading.Lock()

def load_mlc_llm_model():
    global mlc_chat_module
    if mlc_chat_module is None:
        print(f"Attempting to load LLM model: {MLC_MODEL_NAME} from {MLC_MODEL_PATH}...")
        try:
            if not os.path.exists(MLC_MODEL_PATH):
                print(f"Error: MLC LLM model path not found: {MLC_MODEL_PATH}")
                print("Please ensure the MLC LLM model is downloaded and compiled in the specified path.")
                print("Refer to the mlc-llm installation instructions and model download commands.")
                return None

            # ChatModule accepts a model name or a local path to the compiled
            # model directory; pass the directory directly.
            mlc_chat_module = ChatModule(model=MLC_MODEL_PATH)
            print("MLC LLM model loaded successfully.")
        except Exception as e:
            print(f"Error loading MLC LLM model: {e}")
            mlc_chat_module = None
    return mlc_chat_module

def load_clip_model():
    global clip_processor, clip_model
    if clip_model is None:
        print(f"Attempting to load CLIP model: {CLIP_MODEL_NAME}...")
        try:
            clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
            clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)

            if torch.cuda.is_available():
                clip_model.to("cuda")
                print("CLIP model moved to CUDA.")
            print("CLIP model loaded successfully.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            clip_processor, clip_model = None, None
    return clip_processor, clip_model

# Load both models once at import time so the first request does not pay the
# startup cost.
with app.app_context():
    if mlc_chat_module is None:
        load_mlc_llm_model()
    if clip_model is None:
        load_clip_model()

@app.route('/')
def health_check():
    llm_status = "loaded" if mlc_chat_module else "not loaded (check logs)"
    clip_status = "loaded" if clip_model else "not loaded (check logs)"
    return jsonify({
        "status": "Quantum-Enhanced WAN 2.1 Backend is running!",
        "llm_status": llm_status,
        "clip_status": clip_status
    })

@app.route('/embed_image', methods=['POST'])
def embed_image():
    if clip_processor is None or clip_model is None:
        return jsonify({"error": "CLIP model not loaded. Check server logs for details."}), 500

    data = request.get_json()
    image_data_url = data.get('image')

    if not image_data_url:
        return jsonify({"error": "No image data provided"}), 400

    try:
        # Expect a data URL of the form "data:image/png;base64,<payload>".
        header, encoded = image_data_url.split(",", 1)
        image_bytes = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        inputs = clip_processor(images=image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)

        # L2-normalize so cosine similarity downstream reduces to a dot product.
        image_embeddings = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        return jsonify({"embeddings": image_embeddings.squeeze().tolist()})

    except Exception as e:
        print(f"Error embedding image: {e}")
        return jsonify({"error": f"Failed to embed image: {str(e)}"}), 500
|
|
|
|
|
@app.route('/chat/completions', methods=['POST'])
def chat_completions_endpoint():
    if mlc_chat_module is None:
        return jsonify({"error": "LLM model not loaded. Check server logs for details."}), 500

    data = request.get_json()
    prompt = data.get("prompt")
    system_message = data.get("system_message", "You are a creative AI assistant for video generation.")

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    try:
        full_prompt = f"{system_message}\nUser: {prompt}"

        # Serialize LLM access: ChatModule keeps per-conversation state.
        with mlc_lock:
            mlc_chat_module.reset_chat()
            response = mlc_chat_module.generate(full_prompt)

        return jsonify({"completion": response})
    except Exception as e:
        print(f"Error getting chat completion: {e}")
        return jsonify({"error": f"Failed to get chat completion: {str(e)}"}), 500
|
|
|
|
|
@app.route('/generate_frame_guidance', methods=['POST'])
def generate_frame_guidance():
    if mlc_chat_module is None or clip_processor is None or clip_model is None:
        return jsonify({"error": "One or more AI models not loaded. Check server logs for details."}), 500

    data = request.get_json()
    image_data_url = data.get('image')
    prompt = data.get('prompt', 'Quantum interpolation')
    influence = data.get('influence', 5)
    entanglement_depth = data.get('depth', 16)
    frame_number = data.get('frame_number', 0)

    if not image_data_url:
        return jsonify({"error": "No image data provided"}), 400

    try:
        # Decode the incoming data URL into a PIL image.
        header, encoded = image_data_url.split(",", 1)
        image_bytes = base64.b64decode(encoded)
        input_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        clip_inputs = clip_processor(images=input_image, return_tensors="pt")
        if torch.cuda.is_available():
            clip_inputs = {k: v.to("cuda") for k, v in clip_inputs.items()}

        with torch.no_grad():
            image_features = clip_model.get_image_features(**clip_inputs)
        image_embeddings_np = image_features.squeeze().cpu().numpy()
        embedding_snippet = ", ".join([f"{x:.4f}" for x in image_embeddings_np[:10]])

        # Build the director prompt. Effect magnitudes are derived from the
        # user's influence/depth settings with plain Python round()/max().
        llm_prompt = (
            f"You are an AI video director for a quantum diffusion system. Your task is to guide the transformation "
            f"of a video frame based on quantum principles and user input. "
            f"Given the current frame's visual context (CLIP features: [{embedding_snippet}...]), "
            f"the user's creative prompt: '{prompt}', "
            f"and the quantum settings (Quantum Influence: {influence}%, Entanglement Depth: {entanglement_depth} layers), "
            f"describe *precisely* how the quantum diffusion effect should transform the current frame into frame {frame_number + 1}. "
            f"Think of these transformations as manipulating a quantum state that manifests visually. "
            f"Higher influence and depth should lead to more pronounced, chaotic, or surreal quantum effects. "
            f"Focus on quantifiable visual parameters, including: "
            f"color shifts (e.g., 'shift red by +{round(influence / 5)}', 'hue rotate {round(influence * 1.5)}deg'), "
            f"blur (e.g., 'apply gaussian blur radius {max(1, round(influence / 10))}'), "
            f"glitch/distortion (e.g., 'pixel displacement x-axis random {max(5, round(influence / 5))}px', 'chromatic aberration offset {max(1, round(influence / 20))}'), "
            f"zoom/pan (e.g., 'zoom in {1.0 + influence / 2000}x, pan right {round(influence / 10)}px'), "
            f"pattern overlay (e.g., 'overlay subtle static pattern opacity {influence / 200}'), "
            f"motion blur (e.g., 'apply motion blur strength {round(entanglement_depth / 2)}'), "
            f"bloom (e.g., 'add bloom strength {influence / 100}'), "
            f"noise (e.g., 'add noise amount {influence / 50}'), "
            f"vignette (e.g., 'add vignette strength {influence / 200}'), "
            f"or specific quantum-themed visual cues (e.g., 'ripple effect', 'add subtle scanlines opacity {influence / 200}', 'invert colors'). "
            f"Combine these to create a dynamic, quantum-like visual evolution. Ensure the intensity of effects scales with Influence and Depth. "
            f"Be concise and output only the transformation instructions. "
            f"Example: 'shift blue by +{round(influence / 5)}, apply motion blur strength {round(entanglement_depth / 2)}, zoom {1.0 + influence / 2000}x, add subtle scanlines opacity {influence / 200}'.\n"
            f"Transformation Instructions for frame {frame_number + 1}:"
        )

        llm_guidance = ""
        try:
            with mlc_lock:
                mlc_chat_module.reset_chat()
                llm_guidance = mlc_chat_module.generate(llm_prompt)
        except Exception as llm_e:
            print(f"LLM guidance generation failed: {llm_e}. Using fallback guidance.")
            llm_guidance = f"apply subtle glitch effect, shift colors slightly based on quantum influence {influence}%."

        print(f"LLM Guidance: {llm_guidance}")

        return jsonify({
            "guidance": llm_guidance,
            "log": (f"Backend provided guidance for frame {frame_number + 1} based on prompt: '{prompt[:50]}...', "
                    f"influence: {influence}, depth: {entanglement_depth}. LLM guidance: '{llm_guidance[:50]}...'.")
        })
    except Exception as e:
        print(f"Error generating frame guidance: {e}")
        return jsonify({"error": f"Failed to generate frame guidance: {str(e)}"}), 500
|
|
|
|
|
if __name__ == '__main__':
    if not os.path.exists(MLC_MODEL_ARTIFACTS_DIR):
        os.makedirs(MLC_MODEL_ARTIFACTS_DIR)
        print(f"Created model artifacts directory: {MLC_MODEL_ARTIFACTS_DIR}")

    # Note: debug=True starts the Werkzeug reloader, which re-imports this
    # module in a child process and therefore loads both models twice.
    app.run(debug=True, host='0.0.0.0', port=5000)