Spaces:
Build error
| import os | |
| import requests | |
| from flask import Flask, request, jsonify | |
| from transformers import pipeline | |
| from PIL import Image | |
| import io | |
| import base64 | |
| # Import for image generation | |
| from diffusers import AutoPipelineForText2Image | |
# The Flask application object; served by app.run() in the __main__ block.
app = Flask(import_name=__name__)
# --- Configuration ---
# Hugging Face Hub model identifiers.
GEMMA_MODEL_ID: str = "google/gemma-4-E2B-it"
IMAGE_GEN_MODEL_ID: str = "stabilityai/sd-turbo"  # A fast, small Stable Diffusion model for demonstration

# Generation limits and sizes.
MAX_NEW_TOKENS: int = 200  # Passed as max_new_tokens to the Gemma 4 pipeline; adjust for response length
IMAGE_SIZE: tuple[int, int] = (512, 512)  # For generated images
# Determine device for models.
# GPU use is opt-in via USE_GPU=true. It is honoured only if torch actually
# sees a CUDA device. An unset CUDA_VISIBLE_DEVICES means "all GPUs visible",
# so that variable alone is not a reliable signal. An empty
# CUDA_VISIBLE_DEVICES hides every GPU, and torch reports that as unavailable.
# For a CPU-focused Dockerfile this defaults to CPU (-1 / "cpu").
_use_gpu = os.environ.get("USE_GPU", "false").lower() == "true"
if _use_gpu:
    import torch  # imported lazily: only needed for the CUDA availability probe

    _use_gpu = torch.cuda.is_available()
    if not _use_gpu:
        print("USE_GPU=true but CUDA is not available; falling back to CPU.")

if _use_gpu:
    device = 0  # transformers pipeline index of the first GPU
    torch_device_name = "cuda"
else:
    device = -1  # transformers pipeline convention for CPU
    torch_device_name = "cpu"
# --- Model Loading ---
def _load_gemma_pipeline():
    """Load the Gemma 4 multimodal pipeline; return None (and log) on failure."""
    print(f"Loading Gemma 4 multimodal model: {GEMMA_MODEL_ID} on device {torch_device_name} (pipeline device {device})...")
    try:
        pipe = pipeline("any-to-any", model=GEMMA_MODEL_ID, device=device)
    except Exception as e:
        print(f"Error loading Gemma 4 model: {e}")
        return None
    print("Gemma 4 model loaded successfully.")
    return pipe


def _load_image_gen_pipeline():
    """Load the text-to-image pipeline; return None (and log) on failure.

    xFormers memory-efficient attention is an optional GPU-only optimization.
    Any failure enabling it is logged and ignored, so it never marks an
    otherwise working pipeline as failed.
    """
    print(f"Loading Image Generation model: {IMAGE_GEN_MODEL_ID} on device {torch_device_name}...")
    try:
        pipe = AutoPipelineForText2Image.from_pretrained(IMAGE_GEN_MODEL_ID).to(torch_device_name)
    except Exception as e:
        print(f"Error loading Image Generation model: {e}")
        return None
    # Only enable xformers if on GPU
    if torch_device_name == "cuda":
        try:
            # Note: xformers might require a specific CUDA version or manual installation.
            pipe.enable_xformers_memory_efficient_attention()
            print("xFormers enabled for image generation.")
        except Exception as e:
            # Not only ImportError: the call can also fail with other exception
            # types, which previously escaped and were misreported as a load failure.
            print(f"xFormers not installed or not available, skipping memory efficient attention. ({e})")
    print("Image Generation model loaded successfully.")
    return pipe


# Globals read by the endpoints; None means the model failed to load.
gemma_pipeline = _load_gemma_pipeline()
image_gen_pipeline = _load_image_gen_pipeline()
# --- Helper Functions ---
def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('utf-8')
# --- API Endpoints ---
@app.route("/", methods=["GET"])
def home():
    """Liveness endpoint: return a plain-text banner naming the API routes."""
    return "Multimodal AI (Gemma 4) and Image Generation API is running. Use /gemma-predict or /generate-image."
@app.route("/gemma-predict", methods=["POST"])
def gemma_predict():
    """
    Endpoint for Gemma 4 multimodal text generation (image + text -> text).

    Expects a JSON object body with optional keys:
      - "image_base64": base64-encoded image bytes (a "data:...;base64," prefix is accepted)
      - "text_prompt": text prompt
    At least one of them must be non-empty.

    Returns:
      200 {"prediction": <generated text>} on success;
      400 for a missing/non-object JSON body or an undecodable image;
      503 if the Gemma model failed to load at startup;
      500 if generation raises or yields no "generated_text".
    """
    if gemma_pipeline is None:
        return jsonify({"error": "Gemma 4 model not loaded. Please check server logs."}), 503

    # silent=True returns None instead of raising on a missing or non-JSON
    # body. This lets the client get our 400 rather than a generic 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    image_base64 = data.get('image_base64')
    text_prompt = data.get('text_prompt', '')
    if not image_base64 and not text_prompt:
        return jsonify({"error": "At least 'image_base64' or 'text_prompt' must be provided"}), 400

    content = []
    if image_base64:
        try:
            if isinstance(image_base64, str) and image_base64.startswith("data:"):
                # Strip a data-URI header such as "data:image/png;base64,".
                image_base64 = image_base64.partition(",")[2]
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open only parses the header; decode the pixels now so a
            # corrupt payload is reported here as a 400, not later as a 500.
            image.load()
        except Exception as e:
            return jsonify({"error": f"Invalid image_base64 provided: {e}"}), 400
        content.append({"type": "image", "image": image})
    if text_prompt:
        content.append({"type": "text", "text": text_prompt})
    # content is non-empty here: the guard above ensured at least one input,
    # and a bad image already returned.

    full_messages = [{"role": "user", "content": content}]
    try:
        output = gemma_pipeline(full_messages, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)
    except Exception as e:
        print(f"Error during Gemma 4 prediction: {e}")
        return jsonify({"error": f"An error occurred during Gemma 4 prediction: {str(e)}"}), 500

    first = output[0] if output else None
    if isinstance(first, dict) and "generated_text" in first:
        return jsonify({"prediction": first["generated_text"]})
    return jsonify({"error": "Gemma 4 model did not return generated text."}), 500
@app.route("/generate-image", methods=["POST"])
def generate_image():
    """
    Endpoint for text-to-image generation.

    Expects a JSON object body:
      - "prompt" (required): text description of the image
      - "num_inference_steps" (optional int >= 1, default 1)
      - "guidance_scale" (optional number, default 0.0)
    The defaults follow the documented usage of SD-Turbo (the configured
    IMAGE_GEN_MODEL_ID). It is distilled for single-step sampling without
    classifier-free guidance; the pipeline's generic defaults (many steps,
    guidance > 1) are slow and degrade its output. Override them if
    IMAGE_GEN_MODEL_ID is switched to a non-turbo model.

    Returns:
      200 {"image_base64": <PNG as base64>, "prompt": <prompt>} on success;
      400 for a missing/non-object JSON body, missing prompt or bad parameters;
      503 if the image model failed to load at startup;
      500 if generation or encoding raises.
    """
    if image_gen_pipeline is None:
        return jsonify({"error": "Image generation model not loaded. Please check server logs."}), 503

    # silent=True: a missing or non-JSON body yields None, so we return 400 instead of a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Missing 'prompt' for image generation."}), 400

    try:
        num_inference_steps = int(data.get('num_inference_steps', 1))
        guidance_scale = float(data.get('guidance_scale', 0.0))
    except (TypeError, ValueError):
        return jsonify({"error": "'num_inference_steps' must be an integer and 'guidance_scale' a number."}), 400
    if num_inference_steps < 1:
        return jsonify({"error": "'num_inference_steps' must be at least 1."}), 400

    # NOTE(review): assumes IMAGE_SIZE is (width, height); both are 512 today.
    width, height = IMAGE_SIZE
    try:
        generated_image = image_gen_pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
        ).images[0]
        image_base64 = encode_image_to_base64(generated_image)
    except Exception as e:
        print(f"Error during image generation: {e}")
        return jsonify({"error": f"An error occurred during image generation: {str(e)}"}), 500

    return jsonify({"image_base64": image_base64, "prompt": prompt})
@app.route("/status", methods=["GET"])
def status():
    """
    Checks the status of both AI models.

    Returns 200 with a JSON object giving each model's ID, its status
    ("ready" if its pipeline loaded, otherwise "not_loaded"), and the torch
    device name the models were placed on.
    """
    def _state(pipe):
        # Same `is None` test the endpoints use to decide on a 503, so
        # /status never disagrees with them.
        return "ready" if pipe is not None else "not_loaded"

    return jsonify({
        "gemma_4_model_id": GEMMA_MODEL_ID,
        "gemma_4_status": _state(gemma_pipeline),
        "image_gen_model_id": IMAGE_GEN_MODEL_ID,
        "image_gen_status": _state(image_gen_pipeline),
        "device_used": torch_device_name,
    })
# --- Main Execution ---
if __name__ == '__main__':
    # The Werkzeug debugger must never be on by default. Combined with
    # host='0.0.0.0', it lets anyone who can reach the port run arbitrary
    # Python. Opt in with FLASK_DEBUG=1 for local development only.
    debug_mode = os.environ.get("FLASK_DEBUG", "false").lower() in ("1", "true", "yes")
    port = int(os.environ.get("PORT", "5000"))
    # use_reloader=False: the reloader re-imports this module in a child
    # process, which would load both models a second time.
    app.run(host='0.0.0.0', port=port, debug=debug_mode, use_reloader=False)