"""Gradio demo that routes a prompt to either text or image generation.

The prompt is embedded with a sentence-similarity model and compared against
a small list of image-request trigger phrases. When the best cosine
similarity exceeds ``IMAGE_SIMILARITY_THRESHOLD`` the prompt is sent to the
Stable Diffusion pipeline; otherwise it is answered by the Flan-T5 model.
"""

import contextlib

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Load models ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# Text-to-text model
text_model_name = "google/flan-t5-large"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_model = AutoModelForSeq2SeqLM.from_pretrained(text_model_name).to(device)

# Text-to-image model
image_model_id = "runwayml/stable-diffusion-v1-5"
image_pipe = StableDiffusionPipeline.from_pretrained(
    image_model_id,
    # Half precision is only used on GPU; CPU inference stays in float32.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # NOTE(review): the safety checker is disabled here ("optional for
    # debugging" in the original) -- confirm this is intended before
    # exposing the demo publicly.
    safety_checker=None,
)
image_pipe = image_pipe.to(device)

# Sentence similarity model
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Image-like trigger phrases
image_triggers = [
    "generate an image of",
    "draw a",
    "create a picture of",
    "show me a",
    "visualize",
    "render",
    "sketch",
]

# The trigger list is constant, so embed it once at start-up instead of on
# every request.
trigger_embeddings = embedder.encode(image_triggers, convert_to_tensor=True)

# Minimum cosine similarity between the prompt and any trigger phrase for the
# prompt to be treated as an image request.
IMAGE_SIMILARITY_THRESHOLD = 0.65


# --- Core logic ---
def multimodal_agent(prompt: str):
    """Answer ``prompt`` with generated text or a generated image.

    Args:
        prompt: The user's request as entered in the UI textbox.

    Returns:
        A ``(text, image)`` pair matching the two Gradio outputs. Exactly one
        element is populated and the other is ``None``: ``(None, image)``
        when the prompt is routed to the image pipeline, ``(text, None)``
        otherwise. An empty or whitespace-only prompt yields a short
        instruction message as ``text`` without invoking any model.
    """
    # Guard against empty input so no model is run on a meaningless prompt.
    if not prompt or not prompt.strip():
        return "Please enter a prompt.", None

    # Step 1: Semantic similarity to image triggers
    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(prompt_embedding, trigger_embeddings)
    max_score = torch.max(cosine_scores).item()

    # Step 2: Decision branch
    if max_score > IMAGE_SIMILARITY_THRESHOLD:
        # Generate image; autocast is only entered when running on CUDA.
        if device == "cuda":
            precision_ctx = torch.autocast("cuda")
        else:
            precision_ctx = contextlib.nullcontext()
        with precision_ctx:
            image = image_pipe(prompt).images[0]
        return None, image

    # Generate text
    inputs = text_tokenizer(prompt, return_tensors="pt").to(device)
    outputs = text_model.generate(**inputs, max_new_tokens=100)
    text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text, None


# --- UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Smart Multimodal AI Agent\nGive a prompt — It decides text vs image automatically!")
    input_box = gr.Textbox(label="Enter your prompt")
    output_text = gr.Textbox(label="Text Output")
    output_image = gr.Image(label="Image Output")
    btn = gr.Button("Generate")
    # The handler's (text, image) return pair maps onto the two outputs.
    btn.click(multimodal_agent, inputs=input_box, outputs=[output_text, output_image])

demo.launch()