| import gradio as gr |
| import os |
| import base64 |
| import pandas as pd |
| from PIL import Image |
| from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool, Tool |
| from typing import Optional |
| import requests |
| from io import BytesIO |
| import re |
| from pathlib import Path |
| import openai |
| from openai import OpenAI |
| import pdfplumber |
| import numpy as np |
|
|
|
|
| |
def is_image_extension(filename: str) -> bool:
    """Return True when *filename* carries a recognised image file suffix."""
    image_suffixes = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.svg')
    suffix = os.path.splitext(filename)[1].lower()
    return suffix in image_suffixes
|
|
| def load_file(path: str) -> list | dict: |
| """Based on the file extension, load the file into a suitable object.""" |
| |
| image = None |
| text = None |
| ext = Path(path).suffix.lower() |
|
|
| if ext.endswith(".png") or ext.endswith(".jpg") or ext.endswith(".jpeg"): |
| image = Image.open(path).convert("RGB") |
| elif ext.endswith(".xlsx") or ext.endswith(".xls"): |
| text = pd.read_excel(path) |
| elif ext.endswith(".csv"): |
| text = pd.read_csv(path) |
| elif ext.endswith(".pdf"): |
| with pdfplumber.open(path) as pdf: |
| text = "\n".join(page.extract_text() for page in pdf.pages if page.extract_text()) |
| elif ext.endswith(".py") or ext.endswith(".txt"): |
| with open(path, 'r') as f: |
| text = f.read() |
| |
| if image is not None: |
| return [image] |
| elif ext.endswith(".mp3") or ext.endswith(".wav"): |
| return {"audio path": path} |
| else: |
| return {"raw document text": text, "file path": path} |
| |
| def check_format(answer: str | list, *args, **kwargs) -> list: |
| """Check if the answer is a list and not a nested list.""" |
| print("Checking format of the answer:", answer) |
| if isinstance(answer, list): |
| for item in answer: |
| if isinstance(item, list): |
| print("Nested list detected") |
| raise TypeError("Nested lists are not allowed in the final answer.") |
| print("Final answer is a list:") |
| return answer |
| elif isinstance(answer, str): |
| return [answer] |
| elif isinstance(answer, dict): |
| raise TypeError(f"Final answer must be a list, not a dict. Please check the answer format.") |
|
|
|
|
| |
@tool
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma-separated URLs and return them in a list of PIL Images.
    Args:
        image_urls: comma-separated list of URLs to download
    Returns:
        List of PIL.Image.Image objects wrapped by gr.Image
    """
    urls = [u.strip() for u in image_urls.split(",") if u.strip()]
    wrapped = []
    for url in urls:
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            img = Image.open(BytesIO(resp.content)).convert("RGB")
            # Wrap immediately; a second pass over the list was redundant.
            wrapped.append(gr.Image(value=img))
        except Exception as e:
            # Best-effort: a failed URL is logged and skipped, not fatal.
            print(f"Failed to download from {url}: {e}")
    return wrapped
|
|
@tool
def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe audio file using OpenAI Whisper API.
    Args:
        audio_path: path to the audio file to be transcribed.
    Returns:
        str : Transcription of the audio.
    Raises:
        Exception: re-raised after logging if the transcription request fails.
    """
    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
    # The try now covers the call that can actually fail; previously it only
    # wrapped the bare `return`, making the except branch unreachable.
    try:
        with open(audio_path, "rb") as audio:
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
    except Exception as e:
        print(f"Error transcribing audio: {e}")
        raise  # preserve original propagation behaviour for callers
    print(transcript)
    return transcript
|
|
@tool
def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
    """
    Generate an image based on a text prompt using Flux Dev.
    Args:
        prompt: The text prompt to generate the image from.
        neg_prompt: The negative prompt to avoid certain elements in the image.
    Returns:
        Image.Image: The generated image as a PIL Image object.
    """
    # Nebius-hosted OpenAI-compatible endpoint.
    client = OpenAI(
        base_url="https://api.studio.nebius.com/v1",
        api_key=os.environ.get("NEBIUS_API_KEY"),
    )

    completion = client.images.generate(
        model="black-forest-labs/flux-dev",
        prompt=prompt,
        response_format="b64_json",
        extra_body={
            "response_extension": "png",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "seed": -1,
            "negative_prompt": neg_prompt,
        }
    )

    # Decode the base64 payload and materialise it as an RGB PIL image.
    payload = completion.to_dict()['data'][0]['b64_json']
    decoded = Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")

    # NOTE(review): despite the Image.Image annotation, this returns a Gradio
    # component (matching the other tools) — confirm smolagents accepts this.
    return gr.Image(value=decoded, label="Generated Image")
|
|
@tool
def generate_audio(prompt: str, duration: int) -> gr.Component:
    """
    Generate audio from a text prompt using MusicGen.
    Args:
        prompt: The text prompt to generate the audio from.
        duration: Duration of the generated audio in seconds. Max 30 seconds.
    Returns:
        gr.Component: The generated audio as a Gradio Audio component.
    """
    # Proxy tool calling a remote Hugging Face Space.
    client = Tool.from_space(
        space_id="luke9705/MusicGen_custom",
        token=os.environ.get('HF_TOKEN'),
        name="Sound_Generator",
        description="Generate music or sound effects from a text prompt using MusicGen."
    )
    # Clamp to the 30-second maximum instead of duplicating the call per branch.
    sound = client(prompt, min(duration, 30))

    return gr.Audio(value=sound)
|
|
@tool
def generate_audio_from_sample(prompt: str, duration: int, sample_path: Optional[str] = None) -> gr.Component:
    """
    Generate audio from a text prompt + audio sample using MusicGen.
    Args:
        prompt: The text prompt to generate the audio from.
        duration: Duration of the generated audio in seconds. Max 30 seconds.
        sample_path: audio sample path to guide generation.

    Returns:
        gr.Component: The generated audio as a Gradio Audio component.
    """
    # Proxy tool calling a remote Hugging Face Space.
    client = Tool.from_space(
        space_id="luke9705/MusicGen_custom",
        token=os.environ.get('HF_TOKEN'),
        name="Sound_Generator",
        description="Generate music or sound effects from a text prompt using MusicGen."
    )
    # Clamp to the 30-second maximum instead of duplicating the call per branch.
    sound = client(prompt, min(duration, 30), sample_path)

    return gr.Audio(value=sound)
| |
|
|
| |
class Agent:
    """Thin wrapper around a smolagents CodeAgent wired up with the
    search/media tools defined in this module and a custom system prompt."""

    def __init__(self):
        # LLM backend: Gemma 3 27B served through the Nebius provider.
        client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
        # Alternative backend (kept for reference):
        # client = OpenAIServerModel(
        #     model_id="claude-opus-4-20250514",
        #     api_base="https://api.anthropic.com/v1/",
        #     api_key=os.environ["ANTHROPIC_API_KEY"],
        # )
        self.agent = CodeAgent(
            model=client,
            tools=[DuckDuckGoSearchTool(max_results=5),
                   VisitWebpageTool(max_output_length=20000),
                   generate_image,
                   generate_audio_from_sample,
                   generate_audio,
                   download_images,
                   transcribe_audio],
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=3,
            max_steps=6,
            stream_outputs=False,
            final_answer_checks=[check_format]
        )
        # Replace the default system prompt with the project-specific one.
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            self.agent.prompt_templates["system_prompt"] = f.read()

    def __call__(self, message: str,
                 images: Optional[list[Image.Image]] = None,
                 files: Optional[str] = None,
                 conversation_history: Optional[dict] = None) -> str:
        """Run the agent on *message*, forwarding optional images, a loaded
        file object, and the prior conversation as additional arguments."""
        return self.agent.run(message, images=images,
                              additional_args={"files": files,
                                               "conversation_history": conversation_history})
|
|
| |
def respond(message: str, history: dict, web_search: bool = False):
    """Gradio ChatInterface handler: route the user's text (and the first
    attached file, if any) to the global agent.

    Args:
        message: multimodal message dict with "text" and "files" keys.
        history: prior conversation, forwarded to the agent.
        web_search: when False, an explicit no-web-search constraint is
            appended to the prompt.

    Returns:
        The agent's answer (a flat list, per check_format).
    """
    print("history:", history)
    text = message.get("text", "")
    if not web_search:
        # Fixed typo: was "CONTRAINT" — the agent reads this instruction.
        text += "\nADDITIONAL CONSTRAINT: Don't use web search"

    files = message.get("files", [])
    if not files:
        print("No files received." if not web_search else "No files received + web search enabled.")
        answer = agent(text, conversation_history=history)
    else:
        print(f"files received: {files}")
        loaded = load_file(files[0])
        if is_image_extension(files[0]):
            # Images go through the vision channel, not the files channel.
            answer = agent(text, images=loaded, conversation_history=history)
        else:
            answer = agent(text, files=loaded, conversation_history=history)

    print("Agent response:", answer)
    return answer
|
|
def initialize_agent():
    """Construct and return the Agent instance used by the chat handler."""
    new_agent = Agent()
    print("Agent initialized.")
    return new_agent
|
|
| |
|
|
# Module-level singleton agent shared by all chat requests.
# (The previous `global agent` statement was a no-op at module scope.)
agent = initialize_agent()

# Multimodal chat UI; the Web Search checkbox maps to respond()'s third argument.
demo = gr.ChatInterface(
    fn=respond,
    type='messages',
    multimodal=True,
    title='MultiAgent System for Screenplay Creation and Editing',
    show_progress='full',
    fill_height=True,
    fill_width=True,
    save_history=True,
    autoscroll=True,
    additional_inputs=[
        gr.Checkbox(value=False, label="Web Search",
                    info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
                    render=False),
    ],
    additional_inputs_accordion=gr.Accordion(label="Tools available: ", open=True, render=False)
)
|
|
|
|
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|