import os import gradio as gr import torch from PIL import Image import numpy as np import random import time import uuid import json from transformers import ( pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BlipProcessor, BlipForConditionalGeneration, AutoImageProcessor, StableDiffusionPipeline ) from diffusers import DiffusionPipeline, StableVideoDiffusionPipeline import logging from typing import Dict, List, Optional, Union, Any from pydantic import BaseModel, Field # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # Configure cache directory os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache" os.environ["HF_HOME"] = "/tmp/hf_home" # Define API Models with Pydantic class ContentIdeaRequest(BaseModel): prompt: str = Field(..., description="Initial prompt for content idea generation") num_ideas: int = Field(3, description="Number of ideas to generate", ge=1, le=5) creativity: float = Field(0.7, description="Creativity level (0.0-1.0)", ge=0.0, le=1.0) class ContentIdeaResponse(BaseModel): ideas: List[str] = Field(..., description="Generated content ideas") processing_time: float = Field(..., description="Processing time in seconds") model_used: str = Field(..., description="Model used for generation") class TextToImageRequest(BaseModel): prompt: str = Field(..., description="Text prompt for image generation") negative_prompt: Optional[str] = Field(None, description="Negative prompt for image generation") width: int = Field(512, description="Image width") height: int = Field(512, description="Image height") num_inference_steps: int = Field(30, description="Number of inference steps") guidance_scale: float = Field(7.5, description="Guidance scale") class ImageToVideoRequest(BaseModel): image_url: str = Field(..., description="URL of the image to convert to video") motion_strength: float = Field(0.5, description="Motion strength for video generation", ge=0.0, le=1.0) num_frames: int = Field(16, description="Number of frames for the video") class ImageAnalysisRequest(BaseModel): image_url: str = Field(..., description="URL of the image to analyze") analysis_type: str = Field("caption", description="Type of analysis: caption, objects, or detailed") # Cache for models to avoid reloading MODEL_CACHE = {} def get_model(model_name, model_class, tokenizer_class=None, processor_class=None): """Load a model from cache or download it""" if model_name not in MODEL_CACHE: try: logger.info(f"Loading model: {model_name}") start_time = time.time() if processor_class: processor = processor_class.from_pretrained(model_name) model = model_class.from_pretrained(model_name) MODEL_CACHE[model_name] = {"model": model, "processor": processor} elif tokenizer_class: tokenizer = tokenizer_class.from_pretrained(model_name) model = model_class.from_pretrained(model_name, torch_dtype=torch.float16) MODEL_CACHE[model_name] = {"model": model, "tokenizer": tokenizer} else: model = model_class.from_pretrained(model_name) MODEL_CACHE[model_name] = {"model": model} logger.info(f"Model {model_name} loaded in {time.time() - start_time:.2f} seconds") except Exception as e: logger.error(f"Error loading model {model_name}: {str(e)}") raise return MODEL_CACHE[model_name] # Models configuration MODELS = { "text_generator": "distilgpt2", # Lightweight text generation model "text_to_image": "stabilityai/stable-diffusion-2-base", # Free and balanced model "image_to_video": "stabilityai/stable-video-diffusion-img2vid-xt", # Image to video model "image_captioning": "Salesforce/blip-image-captioning-base", # Image captioning model "text_summarization": "facebook/bart-large-cnn", # For summarizing/refining ideas } # ------------ API ROUTE IMPLEMENTATIONS ------------ def generate_content_ideas(request: ContentIdeaRequest) -> ContentIdeaResponse: """Generate content ideas based on a prompt""" try: start_time = time.time() # Get text generation model model_name = MODELS["text_generator"] model_data = get_model(model_name, AutoModelForCausalLM, AutoTokenizer) # Setup generation parameters temperature = 0.5 + (request.creativity * 0.5) # Scale creativity to temperature max_length = 100 + int(request.creativity * 100) # Longer responses for higher creativity generator = pipeline( "text-generation", model=model_data["model"], tokenizer=model_data["tokenizer"], device=0 if torch.cuda.is_available() else -1 ) # Generate multiple ideas ideas = [] for _ in range(request.num_ideas): prompt = f"Generate a creative content idea based on: {request.prompt}\nContent idea:" result = generator( prompt, max_length=max_length, temperature=temperature, num_return_sequences=1, do_sample=True ) # Extract the generated idea and clean it generated_text = result[0]["generated_text"] idea = generated_text.split("Content idea:")[1].strip() ideas.append(idea) processing_time = time.time() - start_time return ContentIdeaResponse( ideas=ideas, processing_time=processing_time, model_used=model_name ) except Exception as e: logger.error(f"Error generating content ideas: {str(e)}") raise gr.Error(f"Failed to generate content ideas: {str(e)}") def text_to_image(request: TextToImageRequest) -> str: """Convert text prompt to image""" try: model_name = MODELS["text_to_image"] # Load StableDiffusionPipeline if not in cache if model_name not in MODEL_CACHE: pipe = StableDiffusionPipeline.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 ) if torch.cuda.is_available(): pipe = pipe.to("cuda") MODEL_CACHE[model_name] = {"pipeline": pipe} pipe = MODEL_CACHE[model_name]["pipeline"] # Generate image image = pipe( prompt=request.prompt, negative_prompt=request.negative_prompt, width=request.width, height=request.height, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale ).images[0] # Save image output_dir = "outputs" os.makedirs(output_dir, exist_ok=True) filename = f"{output_dir}/image_{uuid.uuid4()}.png" image.save(filename) return filename except Exception as e: logger.error(f"Error in text-to-image conversion: {str(e)}") raise gr.Error(f"Failed to generate image: {str(e)}") def image_to_video(request: ImageToVideoRequest) -> str: """Convert image to video with motion""" try: image_path = request.image_url if not os.path.exists(image_path): raise gr.Error(f"Image file not found: {image_path}") # Load image image = Image.open(image_path) # Load StableVideoDiffusionPipeline model_name = MODELS["image_to_video"] if model_name not in MODEL_CACHE: pipe = StableVideoDiffusionPipeline.from_pretrained( model_name, torch_dtype=torch.float16, variant="fp16" ) if torch.cuda.is_available(): pipe = pipe.to("cuda") MODEL_CACHE[model_name] = {"pipeline": pipe} pipe = MODEL_CACHE[model_name]["pipeline"] # Generate video frames result = pipe( image, motion_bucket_id=int(request.motion_strength * 100), # Convert to bucket ID num_frames=request.num_frames ).frames[0] # Save video frames as GIF output_dir = "outputs" os.makedirs(output_dir, exist_ok=True) filename = f"{output_dir}/video_{uuid.uuid4()}.gif" result.save( filename, save_all=True, append_images=result[1:], optimize=True, duration=100, # ms between frames loop=0 # Loop forever ) return filename except Exception as e: logger.error(f"Error in image-to-video conversion: {str(e)}") raise gr.Error(f"Failed to generate video: {str(e)}") def analyze_image(request: ImageAnalysisRequest) -> Dict[str, Any]: """Analyze image content""" try: image_path = request.image_url if not os.path.exists(image_path): raise gr.Error(f"Image file not found: {image_path}") # Load image captioning model model_name = MODELS["image_captioning"] if model_name not in MODEL_CACHE: processor = BlipProcessor.from_pretrained(model_name) model = BlipForConditionalGeneration.from_pretrained(model_name) MODEL_CACHE[model_name] = { "processor": processor, "model": model } processor = MODEL_CACHE[model_name]["processor"] model = MODEL_CACHE[model_name]["model"] # Process image image = Image.open(image_path).convert('RGB') inputs = processor(image, return_tensors="pt") # Generate caption out = model.generate(**inputs) caption = processor.decode(out[0], skip_special_tokens=True) # Return different analysis based on type if request.analysis_type == "caption": return {"caption": caption} elif request.analysis_type == "objects": # This is a simplified approach - object detection would require a different model keywords = caption.replace(",", "").replace(".", "").split() objects = [word for word in keywords if len(word) > 3] return {"caption": caption, "objects": objects} elif request.analysis_type == "detailed": # For detailed analysis, we would enhance the caption enhanced_caption = f"The image shows {caption}. This appears to be a {caption.split()[0]} scene." return { "caption": caption, "detailed_description": enhanced_caption, "analysis_type": "basic visual elements" } except Exception as e: logger.error(f"Error in image analysis: {str(e)}") raise gr.Error(f"Failed to analyze image: {str(e)}") # ------------ GRADIO INTERFACE ------------ def create_gradio_blocks(): """Create Gradio interface with tab-based organization""" with gr.Blocks(title="Multi-Modal Content API") as demo: gr.Markdown("# 🎨 Multi-Modal Content Generation API") gr.Markdown("Generate content ideas, convert text to images, images to videos, and analyze images.") with gr.Tabs(): # Content Idea Generator Tab with gr.TabItem("Content Idea Generator"): with gr.Row(): with gr.Column(): idea_prompt = gr.Textbox(label="Prompt for Content Ideas", placeholder="Enter a starting point for content ideas...") num_ideas = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of Ideas") creativity = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Creativity Level") idea_generate_btn = gr.Button("Generate Content Ideas") with gr.Column(): idea_output = gr.JSON(label="Generated Content Ideas") # Text to Image Tab with gr.TabItem("Text to Image"): with gr.Row(): with gr.Column(): img_prompt = gr.Textbox(label="Image Prompt", placeholder="Describe the image you want to create...") img_negative_prompt = gr.Textbox(label="Negative Prompt (Optional)", placeholder="What to exclude from the image...") with gr.Row(): img_width = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Width") img_height = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Height") with gr.Row(): img_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps") img_guidance = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale") img_generate_btn = gr.Button("Generate Image") with gr.Column(): img_output = gr.Image(label="Generated Image") # Image to Video Tab with gr.TabItem("Image to Video"): with gr.Row(): with gr.Column(): vid_image = gr.Image(label="Upload Image", type="filepath") vid_motion = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Motion Strength") vid_frames = gr.Slider(minimum=8, maximum=24, value=16, step=8, label="Number of Frames") vid_generate_btn = gr.Button("Generate Video") with gr.Column(): vid_output = gr.Video(label="Generated Video") # Image Analysis Tab with gr.TabItem("Image Analysis"): with gr.Row(): with gr.Column(): analysis_image = gr.Image(label="Upload Image for Analysis", type="filepath") analysis_type = gr.Radio( ["caption", "objects", "detailed"], label="Analysis Type", value="caption" ) analysis_btn = gr.Button("Analyze Image") with gr.Column(): analysis_output = gr.JSON(label="Analysis Results") # Set up event handlers idea_generate_btn.click( fn=lambda prompt, num, creativity: generate_content_ideas( ContentIdeaRequest(prompt=prompt, num_ideas=num, creativity=creativity) ), inputs=[idea_prompt, num_ideas, creativity], outputs=idea_output ) img_generate_btn.click( fn=lambda prompt, neg_prompt, width, height, steps, guidance: text_to_image( TextToImageRequest( prompt=prompt, negative_prompt=neg_prompt, width=width, height=height, num_inference_steps=steps, guidance_scale=guidance ) ), inputs=[img_prompt, img_negative_prompt, img_width, img_height, img_steps, img_guidance], outputs=img_output ) vid_generate_btn.click( fn=lambda image, motion, frames: image_to_video( ImageToVideoRequest( image_url=image, motion_strength=motion, num_frames=frames ) ), inputs=[vid_image, vid_motion, vid_frames], outputs=vid_output ) analysis_btn.click( fn=lambda image, analysis_type: analyze_image( ImageAnalysisRequest( image_url=image, analysis_type=analysis_type ) ), inputs=[analysis_image, analysis_type], outputs=analysis_output ) return demo # ------------ API ENDPOINTS FOR PROGRAMMATIC ACCESS ------------ def build_fastapi(): """Create FastAPI endpoints for programmatic access""" from fastapi import FastAPI, UploadFile, File, HTTPException from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware app = FastAPI( title="Multi-Modal Content API", description="API for content generation, image creation, video generation, and image analysis", version="1.0.0" ) # Enable CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") def read_root(): return {"message": "Welcome to the Multi-Modal Content API"} @app.post("/api/generate-ideas", response_model=ContentIdeaResponse) def api_generate_ideas(request: ContentIdeaRequest): return generate_content_ideas(request) @app.post("/api/text-to-image") def api_text_to_image(request: TextToImageRequest): try: image_path = text_to_image(request) return FileResponse( path=image_path, media_type="image/png", filename=os.path.basename(image_path) ) except Exception as e: return JSONResponse( status_code=500, content={"error": str(e)} ) @app.post("/api/image-to-video") async def api_image_to_video( motion_strength: float = Form(0.5), num_frames: int = Form(16), image: UploadFile = File(...) ): try: # Save uploaded image image_path = f"uploads/{uuid.uuid4()}.png" os.makedirs("uploads", exist_ok=True) with open(image_path, "wb") as f: f.write(await image.read()) # Generate video request = ImageToVideoRequest( image_url=image_path, motion_strength=motion_strength, num_frames=num_frames ) video_path = image_to_video(request) return FileResponse( path=video_path, media_type="image/gif", filename=os.path.basename(video_path) ) except Exception as e: return JSONResponse( status_code=500, content={"error": str(e)} ) @app.post("/api/analyze-image") async def api_analyze_image( analysis_type: str = Form("caption"), image: UploadFile = File(...) ): try: # Save uploaded image image_path = f"uploads/{uuid.uuid4()}.png" os.makedirs("uploads", exist_ok=True) with open(image_path, "wb") as f: f.write(await image.read()) # Analyze image request = ImageAnalysisRequest( image_url=image_path, analysis_type=analysis_type ) results = analyze_image(request) return results except Exception as e: return JSONResponse( status_code=500, content={"error": str(e)} ) return app # Export FastAPI app for serverless deployment app = build_fastapi() # ------------ MAIN ENTRY POINT ------------ if __name__ == "__main__": # Create directories os.makedirs("outputs", exist_ok=True) os.makedirs("uploads", exist_ok=True) # Launch Gradio interface demo = create_gradio_blocks() demo.launch()