|
|
import os |
|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import random |
|
|
import time |
|
|
import uuid |
|
|
import json |
|
|
from transformers import ( |
|
|
pipeline, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
AutoModelForSeq2SeqLM, |
|
|
BlipProcessor, |
|
|
BlipForConditionalGeneration, |
|
|
AutoImageProcessor, |
|
|
StableDiffusionPipeline |
|
|
) |
|
|
from diffusers import DiffusionPipeline, StableVideoDiffusionPipeline |
|
|
import logging |
|
|
from typing import Dict, List, Optional, Union, Any |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
# Timestamped logging so the model-load timings reported below are meaningful.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

logger = logging.getLogger(__name__)

# Redirect Hugging Face caches to /tmp (writable on container/Spaces-style hosts).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers releases
# in favor of HF_HOME (also set below) — confirm against the pinned version.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"

os.environ["HF_HOME"] = "/tmp/hf_home"
|
|
|
|
|
|
|
|
class ContentIdeaRequest(BaseModel):
    """Request payload for content-idea generation (see generate_content_ideas)."""

    prompt: str = Field(..., description="Initial prompt for content idea generation")
    num_ideas: int = Field(3, description="Number of ideas to generate", ge=1, le=5)
    # 0.0 = conservative sampling, 1.0 = maximum temperature/length budget.
    creativity: float = Field(0.7, description="Creativity level (0.0-1.0)", ge=0.0, le=1.0)
|
|
|
|
|
class ContentIdeaResponse(BaseModel):
    """Response payload returned by generate_content_ideas."""

    ideas: List[str] = Field(..., description="Generated content ideas")
    processing_time: float = Field(..., description="Processing time in seconds")
    model_used: str = Field(..., description="Model used for generation")
|
|
|
|
|
class TextToImageRequest(BaseModel):
    """Request payload for Stable Diffusion text-to-image generation."""

    prompt: str = Field(..., description="Text prompt for image generation")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt for image generation")
    # NOTE(review): Stable Diffusion expects dimensions divisible by 8; the UI
    # sliders enforce step=64, but direct API callers are not validated here.
    width: int = Field(512, description="Image width")
    height: int = Field(512, description="Image height")
    num_inference_steps: int = Field(30, description="Number of inference steps")
    guidance_scale: float = Field(7.5, description="Guidance scale")
|
|
|
|
|
class ImageToVideoRequest(BaseModel):
    """Request payload for image-to-video generation.

    Despite the field name, `image_url` is treated as a local file path by
    image_to_video() (checked with os.path.exists).
    """

    image_url: str = Field(..., description="URL of the image to convert to video")
    # Scaled by 100 and passed as SVD's motion_bucket_id.
    motion_strength: float = Field(0.5, description="Motion strength for video generation", ge=0.0, le=1.0)
    num_frames: int = Field(16, description="Number of frames for the video")
|
|
|
|
|
class ImageAnalysisRequest(BaseModel):
    """Request payload for image analysis.

    `image_url` is treated as a local file path by analyze_image().
    """

    image_url: str = Field(..., description="URL of the image to analyze")
    analysis_type: str = Field("caption", description="Type of analysis: caption, objects, or detailed")
|
|
|
|
|
|
|
|
# Process-wide cache of loaded models/pipelines keyed by model name, so
# repeated requests don't reload weights from disk.
MODEL_CACHE = {}
|
|
|
|
|
def get_model(model_name, model_class, tokenizer_class=None, processor_class=None):
    """Load a model (plus its processor or tokenizer) from cache or download it.

    Args:
        model_name: Hugging Face Hub model id.
        model_class: class with a ``from_pretrained`` constructor for the model.
        tokenizer_class: optional tokenizer class (text models).
        processor_class: optional processor class (vision models); takes
            precedence over ``tokenizer_class`` when both are given.

    Returns:
        The cached dict: ``{"model": ...}`` plus ``"processor"`` or
        ``"tokenizer"`` depending on which class was supplied.

    Raises:
        Re-raises any exception from ``from_pretrained`` after logging it.
    """
    if model_name not in MODEL_CACHE:
        try:
            logger.info(f"Loading model: {model_name}")
            start_time = time.time()

            # BUG FIX: fp16 was requested unconditionally in the tokenizer
            # branch, which breaks CPU-only inference (many ops lack half
            # precision kernels on CPU).  Use fp16 only when CUDA is present.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            if processor_class:
                processor = processor_class.from_pretrained(model_name)
                model = model_class.from_pretrained(model_name)
                MODEL_CACHE[model_name] = {"model": model, "processor": processor}
            elif tokenizer_class:
                tokenizer = tokenizer_class.from_pretrained(model_name)
                model = model_class.from_pretrained(model_name, torch_dtype=dtype)
                MODEL_CACHE[model_name] = {"model": model, "tokenizer": tokenizer}
            else:
                model = model_class.from_pretrained(model_name)
                MODEL_CACHE[model_name] = {"model": model}

            logger.info(f"Model {model_name} loaded in {time.time() - start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            raise

    return MODEL_CACHE[model_name]
|
|
|
|
|
|
|
|
# Hugging Face Hub model IDs used by the handlers below, keyed by task.
MODELS = {
    "text_generator": "distilgpt2",                                      # content-idea generation
    "text_to_image": "stabilityai/stable-diffusion-2-base",              # Stable Diffusion
    "image_to_video": "stabilityai/stable-video-diffusion-img2vid-xt",   # SVD img2vid
    "image_captioning": "Salesforce/blip-image-captioning-base",         # BLIP captioning
    # NOTE(review): "text_summarization" is never referenced in this file.
    "text_summarization": "facebook/bart-large-cnn",
}
|
|
|
|
|
|
|
|
|
|
|
def generate_content_ideas(request: ContentIdeaRequest) -> ContentIdeaResponse:
    """Generate content ideas based on a prompt.

    Args:
        request: prompt text, number of ideas (1-5) and creativity (0-1).

    Returns:
        ContentIdeaResponse with the ideas, wall-clock time and model id.

    Raises:
        gr.Error: wrapping any failure during generation.
    """
    try:
        start_time = time.time()

        model_name = MODELS["text_generator"]
        model_data = get_model(model_name, AutoModelForCausalLM, AutoTokenizer)

        # Map creativity (0-1) onto sampling temperature (0.5-1.0) and the
        # token budget (100-200, prompt tokens included).
        temperature = 0.5 + (request.creativity * 0.5)
        max_length = 100 + int(request.creativity * 100)

        # Cheap to rebuild per call: the heavy model/tokenizer come from cache.
        generator = pipeline(
            "text-generation",
            model=model_data["model"],
            tokenizer=model_data["tokenizer"],
            device=0 if torch.cuda.is_available() else -1
        )

        marker = "Content idea:"
        ideas = []
        for _ in range(request.num_ideas):
            prompt = f"Generate a creative content idea based on: {request.prompt}\n{marker}"
            result = generator(
                prompt,
                max_length=max_length,
                temperature=temperature,
                num_return_sequences=1,
                do_sample=True,
                # GPT-2 has no pad token; reuse EOS to avoid per-call warnings.
                pad_token_id=generator.tokenizer.eos_token_id
            )

            generated_text = result[0]["generated_text"]
            # BUG FIX: split(marker)[1] raised IndexError whenever the marker
            # was absent from the output (e.g. heavily truncated generations);
            # partition() degrades gracefully to the full text instead.
            _, sep, tail = generated_text.partition(marker)
            ideas.append((tail if sep else generated_text).strip())

        processing_time = time.time() - start_time

        return ContentIdeaResponse(
            ideas=ideas,
            processing_time=processing_time,
            model_used=model_name
        )

    except Exception as e:
        logger.error(f"Error generating content ideas: {str(e)}")
        raise gr.Error(f"Failed to generate content ideas: {str(e)}")
|
|
|
|
|
def text_to_image(request: TextToImageRequest) -> str:
    """Render the request's prompt with Stable Diffusion and save it as a PNG.

    Args:
        request: prompt, optional negative prompt, dimensions and sampler
            settings.

    Returns:
        Path of the saved PNG under ``outputs/``.

    Raises:
        gr.Error: wrapping any failure during generation.
    """
    try:
        model_name = MODELS["text_to_image"]

        # Build and cache the diffusion pipeline lazily on first use.
        cached = MODEL_CACHE.get(model_name)
        if cached is None:
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            sd_pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype)
            if torch.cuda.is_available():
                sd_pipe = sd_pipe.to("cuda")
            cached = {"pipeline": sd_pipe}
            MODEL_CACHE[model_name] = cached

        sd_pipe = cached["pipeline"]

        generation = sd_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale
        )
        image = generation.images[0]

        out_dir = "outputs"
        os.makedirs(out_dir, exist_ok=True)

        out_path = f"{out_dir}/image_{uuid.uuid4()}.png"
        image.save(out_path)

        return out_path

    except Exception as e:
        logger.error(f"Error in text-to-image conversion: {str(e)}")
        raise gr.Error(f"Failed to generate image: {str(e)}")
|
|
|
|
|
def image_to_video(request: ImageToVideoRequest) -> str:
    """Animate a still image into a short GIF with Stable Video Diffusion.

    Args:
        request: local image path (``image_url``), motion strength (0-1,
            scaled x100 into SVD's motion_bucket_id) and frame count.

    Returns:
        Path of the saved GIF under ``outputs/``.

    Raises:
        gr.Error: if the image file is missing or generation fails.  (The
        missing-file gr.Error raised inside the try is re-wrapped by the
        except clause below, as in the original flow.)
    """
    try:
        image_path = request.image_url
        if not os.path.exists(image_path):
            raise gr.Error(f"Image file not found: {image_path}")

        image = Image.open(image_path)

        model_name = MODELS["image_to_video"]
        if model_name not in MODEL_CACHE:
            # BUG FIX: fp16 weights were requested even without CUDA, which
            # fails on CPU-only hosts; only use the fp16 variant on GPU.
            if torch.cuda.is_available():
                pipe = StableVideoDiffusionPipeline.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    variant="fp16"
                )
                pipe = pipe.to("cuda")
            else:
                pipe = StableVideoDiffusionPipeline.from_pretrained(model_name)

            MODEL_CACHE[model_name] = {"pipeline": pipe}

        pipe = MODEL_CACHE[model_name]["pipeline"]

        # .frames is a list of videos; [0] selects the list of PIL frames of
        # the first (only) video in the batch.
        frames = pipe(
            image,
            motion_bucket_id=int(request.motion_strength * 100),
            num_frames=request.num_frames
        ).frames[0]

        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)

        filename = f"{output_dir}/video_{uuid.uuid4()}.gif"
        # BUG FIX: the original called .save() on the frame *list* (lists have
        # no .save); a GIF must be saved from the first frame with the rest
        # passed via append_images.
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            optimize=True,
            duration=100,
            loop=0
        )

        return filename

    except Exception as e:
        logger.error(f"Error in image-to-video conversion: {str(e)}")
        raise gr.Error(f"Failed to generate video: {str(e)}")
|
|
|
|
|
def analyze_image(request: ImageAnalysisRequest) -> Dict[str, Any]:
    """Analyze image content with a BLIP captioning model.

    Args:
        request: local image path (``image_url``) and ``analysis_type``
            ("caption", "objects" or "detailed").

    Returns:
        Dict that always contains a "caption" key; extra keys depend on the
        analysis type.

    Raises:
        gr.Error: if the file is missing or analysis fails.
    """
    try:
        image_path = request.image_url
        if not os.path.exists(image_path):
            raise gr.Error(f"Image file not found: {image_path}")

        # Lazily load and cache the BLIP model + processor.
        model_name = MODELS["image_captioning"]
        if model_name not in MODEL_CACHE:
            processor = BlipProcessor.from_pretrained(model_name)
            model = BlipForConditionalGeneration.from_pretrained(model_name)

            MODEL_CACHE[model_name] = {
                "processor": processor,
                "model": model
            }

        processor = MODEL_CACHE[model_name]["processor"]
        model = MODEL_CACHE[model_name]["model"]

        image = Image.open(image_path).convert('RGB')
        inputs = processor(image, return_tensors="pt")

        out = model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)

        if request.analysis_type == "caption":
            return {"caption": caption}
        elif request.analysis_type == "objects":
            # Crude keyword extraction: strip punctuation, keep words > 3 chars.
            keywords = caption.replace(",", "").replace(".", "").split()
            objects = [word for word in keywords if len(word) > 3]
            return {"caption": caption, "objects": objects}
        elif request.analysis_type == "detailed":
            # BUG FIX: caption.split()[0] raised IndexError on an empty
            # caption; fall back to a neutral word.
            words = caption.split()
            first_word = words[0] if words else "general"
            enhanced_caption = f"The image shows {caption}. This appears to be a {first_word} scene."
            return {
                "caption": caption,
                "detailed_description": enhanced_caption,
                "analysis_type": "basic visual elements"
            }
        else:
            # BUG FIX: unrecognized analysis types previously fell through
            # and returned None; return the plain caption as a sane default.
            return {"caption": caption, "analysis_type": request.analysis_type}

    except Exception as e:
        logger.error(f"Error in image analysis: {str(e)}")
        raise gr.Error(f"Failed to analyze image: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
def create_gradio_blocks():
    """Create Gradio interface with tab-based organization.

    Builds four tabs (idea generation, text-to-image, image-to-video, image
    analysis), then wires each button to its module-level handler through a
    lambda that packs the raw widget values into the pydantic request models.

    Returns:
        The assembled (not yet launched) gr.Blocks demo.
    """
    with gr.Blocks(title="Multi-Modal Content API") as demo:
        gr.Markdown("# 🎨 Multi-Modal Content Generation API")
        gr.Markdown("Generate content ideas, convert text to images, images to videos, and analyze images.")

        with gr.Tabs():
            # --- Tab 1: content idea generation ---------------------------
            with gr.TabItem("Content Idea Generator"):
                with gr.Row():
                    with gr.Column():
                        idea_prompt = gr.Textbox(label="Prompt for Content Ideas", placeholder="Enter a starting point for content ideas...")
                        num_ideas = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of Ideas")
                        creativity = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Creativity Level")
                        idea_generate_btn = gr.Button("Generate Content Ideas")

                    with gr.Column():
                        idea_output = gr.JSON(label="Generated Content Ideas")

            # --- Tab 2: text -> image -------------------------------------
            with gr.TabItem("Text to Image"):
                with gr.Row():
                    with gr.Column():
                        img_prompt = gr.Textbox(label="Image Prompt", placeholder="Describe the image you want to create...")
                        img_negative_prompt = gr.Textbox(label="Negative Prompt (Optional)", placeholder="What to exclude from the image...")

                        # step=64 keeps dimensions Stable-Diffusion-friendly.
                        with gr.Row():
                            img_width = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Width")
                            img_height = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Height")

                        with gr.Row():
                            img_steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Inference Steps")
                            img_guidance = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale")

                        img_generate_btn = gr.Button("Generate Image")

                    with gr.Column():
                        img_output = gr.Image(label="Generated Image")

            # --- Tab 3: image -> video ------------------------------------
            with gr.TabItem("Image to Video"):
                with gr.Row():
                    with gr.Column():
                        vid_image = gr.Image(label="Upload Image", type="filepath")
                        vid_motion = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Motion Strength")
                        vid_frames = gr.Slider(minimum=8, maximum=24, value=16, step=8, label="Number of Frames")
                        vid_generate_btn = gr.Button("Generate Video")

                    with gr.Column():
                        # NOTE(review): image_to_video() returns a .gif path;
                        # confirm gr.Video renders GIFs in the pinned Gradio
                        # version.
                        vid_output = gr.Video(label="Generated Video")

            # --- Tab 4: image analysis ------------------------------------
            with gr.TabItem("Image Analysis"):
                with gr.Row():
                    with gr.Column():
                        analysis_image = gr.Image(label="Upload Image for Analysis", type="filepath")
                        analysis_type = gr.Radio(
                            ["caption", "objects", "detailed"],
                            label="Analysis Type",
                            value="caption"
                        )
                        analysis_btn = gr.Button("Analyze Image")

                    with gr.Column():
                        analysis_output = gr.JSON(label="Analysis Results")

        # Wire buttons to handlers; the lambdas adapt raw widget values into
        # the pydantic request models shared with the FastAPI layer.
        idea_generate_btn.click(
            fn=lambda prompt, num, creativity: generate_content_ideas(
                ContentIdeaRequest(prompt=prompt, num_ideas=num, creativity=creativity)
            ),
            inputs=[idea_prompt, num_ideas, creativity],
            outputs=idea_output
        )

        img_generate_btn.click(
            fn=lambda prompt, neg_prompt, width, height, steps, guidance: text_to_image(
                TextToImageRequest(
                    prompt=prompt,
                    negative_prompt=neg_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=steps,
                    guidance_scale=guidance
                )
            ),
            inputs=[img_prompt, img_negative_prompt, img_width, img_height, img_steps, img_guidance],
            outputs=img_output
        )

        vid_generate_btn.click(
            fn=lambda image, motion, frames: image_to_video(
                ImageToVideoRequest(
                    image_url=image,
                    motion_strength=motion,
                    num_frames=frames
                )
            ),
            inputs=[vid_image, vid_motion, vid_frames],
            outputs=vid_output
        )

        analysis_btn.click(
            fn=lambda image, analysis_type: analyze_image(
                ImageAnalysisRequest(
                    image_url=image,
                    analysis_type=analysis_type
                )
            ),
            inputs=[analysis_image, analysis_type],
            outputs=analysis_output
        )

    return demo
|
|
|
|
|
|
|
|
|
|
|
def build_fastapi():
    """Create FastAPI endpoints for programmatic access.

    Mirrors the Gradio handlers as a JSON/multipart HTTP API with permissive
    CORS.  Imports are function-local, matching the original layout.

    Returns:
        The configured FastAPI application instance.
    """
    # BUG FIX: `Form` is used by the upload endpoints below but was never
    # imported, so calling build_fastapi() raised NameError at decoration time.
    from fastapi import FastAPI, UploadFile, File, Form, HTTPException
    from fastapi.responses import FileResponse, JSONResponse
    from fastapi.middleware.cors import CORSMiddleware

    app = FastAPI(
        title="Multi-Modal Content API",
        description="API for content generation, image creation, video generation, and image analysis",
        version="1.0.0"
    )

    # Fully open CORS: any origin, method and header.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/")
    def read_root():
        """Landing/health endpoint."""
        return {"message": "Welcome to the Multi-Modal Content API"}

    @app.post("/api/generate-ideas", response_model=ContentIdeaResponse)
    def api_generate_ideas(request: ContentIdeaRequest):
        """Generate content ideas from a JSON request body."""
        return generate_content_ideas(request)

    @app.post("/api/text-to-image")
    def api_text_to_image(request: TextToImageRequest):
        """Generate an image and stream the PNG back; 500 + JSON on failure."""
        try:
            image_path = text_to_image(request)
            return FileResponse(
                path=image_path,
                media_type="image/png",
                filename=os.path.basename(image_path)
            )
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": str(e)}
            )

    @app.post("/api/image-to-video")
    async def api_image_to_video(
        motion_strength: float = Form(0.5),
        num_frames: int = Form(16),
        image: UploadFile = File(...)
    ):
        """Animate an uploaded image and return the resulting GIF."""
        try:
            # Persist the upload so the PIL-based pipeline can open it by path.
            image_path = f"uploads/{uuid.uuid4()}.png"
            os.makedirs("uploads", exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(await image.read())

            request = ImageToVideoRequest(
                image_url=image_path,
                motion_strength=motion_strength,
                num_frames=num_frames
            )

            video_path = image_to_video(request)

            return FileResponse(
                path=video_path,
                media_type="image/gif",
                filename=os.path.basename(video_path)
            )
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": str(e)}
            )

    @app.post("/api/analyze-image")
    async def api_analyze_image(
        analysis_type: str = Form("caption"),
        image: UploadFile = File(...)
    ):
        """Caption/analyze an uploaded image and return the JSON results."""
        try:
            image_path = f"uploads/{uuid.uuid4()}.png"
            os.makedirs("uploads", exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(await image.read())

            request = ImageAnalysisRequest(
                image_url=image_path,
                analysis_type=analysis_type
            )

            results = analyze_image(request)
            return results
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": str(e)}
            )

    return app
|
|
|
|
|
|
|
|
# Module-level ASGI app so servers can import it (e.g. `uvicorn module:app`).
app = build_fastapi()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Ensure the directories the handlers write to exist before serving.
    os.makedirs("outputs", exist_ok=True)
    os.makedirs("uploads", exist_ok=True)

    # Launch the Gradio UI (blocking).  The FastAPI `app` above is built
    # separately and is not mounted into this launch.
    demo = create_gradio_blocks()
    demo.launch()