luke9705 committed on
Commit
5a4500e
·
verified ·
1 Parent(s): f16cdd3

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +145 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import pandas as pd
4
+ from PIL import Image
5
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool
6
+ from typing import Optional
7
+ import requests
8
+ from io import BytesIO
9
+ import re
10
+ from pathlib import Path
11
+ import openai
12
+
13
+ ## utility functions
14
def is_image_extension(filename: str) -> bool:
    """Report whether ``filename`` ends in a recognised image extension.

    Only the final suffix (as split off by ``os.path.splitext``) is examined,
    and the comparison is case-insensitive.

    Note: currently not called anywhere in this file; kept as a utility.
    """
    known_suffixes = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.svg')
    _, suffix = os.path.splitext(filename)  # splitext returns (root, ext)
    return suffix.lower() in known_suffixes
18
+
19
def load_file(path: str) -> dict:
    """Load a file into an object suited to its extension.

    Args:
        path: Filesystem path of the file to load (a single path, not a list).

    Returns:
        A dict with the keys ``"image"``, ``"excel"``, ``"csv"`` and
        ``"raw text"``. At most one of them is populated, the rest are ``None``:

        * ``.png`` / ``.jpg`` / ``.jpeg`` -> ``"image"``: PIL image converted to RGB
        * ``.xlsx`` / ``.xls``            -> ``"excel"``: pandas DataFrame
        * ``.csv``                        -> ``"csv"``: pandas DataFrame
        * ``.py`` / ``.txt``              -> ``"raw text"``: file contents as ``str``

        ``.mp3`` / ``.wav`` files populate nothing; their bytes are copied to
        ``output.mp3`` in the current working directory, which is the path
        ``transcribe_audio`` reads. Any other extension yields all ``None``.
    """
    image = None
    excel = None
    csv = None
    text = None
    ext = Path(path).suffix.lower()  # final suffix including the dot, e.g. ".png"
    print(f"ext: {ext}")

    if ext in {".png", ".jpg", ".jpeg"}:
        image = Image.open(path).convert("RGB")  # pillow object
    elif ext in {".xlsx", ".xls"}:
        excel = pd.read_excel(path)  # DataFrame
    elif ext == ".csv":
        csv = pd.read_csv(path)  # DataFrame
    elif ext in {".py", ".txt"}:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()  # plain text str
    elif ext in {".mp3", ".wav"}:
        # Copy the upload to the hardcoded location the transcription tool
        # reads. The source must be *read*, never opened for writing:
        # opening it with "wb" would truncate the user's file.
        Path("output.mp3").write_bytes(Path(path).read_bytes())

    return {"image": image, "excel": excel, "csv": csv, "raw text": text}
43
+
44
+
45
+ ## tools definition
46
@tool
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma‐separated URLs and return them in a list of PIL Images.
    Args:
        image_urls: comma‐separated list of URLs to download
    Returns:
        List of PIL.Image.Image objects
    """
    # NOTE(review): the docstring above is deliberately left word-for-word;
    # the @tool decorator presumably derives the tool description from it.
    downloaded = []
    for candidate in image_urls.split(","):
        url = candidate.strip()
        if not url:
            # Blank entries (e.g. a trailing comma) are ignored.
            continue
        try:
            # Fetch the raw bytes; HTTP error statuses are turned into exceptions.
            response = requests.get(url, timeout=10)
            response.raise_for_status()

            # Decode the bytes in memory and normalise to RGB.
            buffer = BytesIO(response.content)
            downloaded.append(Image.open(buffer).convert("RGB"))
        except Exception as e:
            # Best effort: a failing URL is reported and skipped, not fatal.
            print(f"Failed to download from {url}: {e}")
    return downloaded
70
+
71
@tool  # since they gave us OpenAI API credits, we can keep using it
def transcribe_audio() -> str:
    """
    Transcribe audio file using OpenAI Whisper API.
    The path to the audio file is hardcoded as "output.mp3". Don't need to pass it as an argument.
    Returns:
        str: Transcription of the audio, or an error message if the transcription failed.
    """
    # The guard must surround the operations that can actually fail (opening
    # the file, the API request); a bare `return` can never raise.
    try:
        client = openai.Client(api_key=os.getenv("OPEN_AI_API_KEY"))
        with open("output.mp3", "rb") as audio:  # to modify path because it is arriving from gradio
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
    except Exception as e:
        # Report the failure and still honour the declared `str` return type
        # instead of implicitly returning None.
        error_message = f"Error transcribing audio: {e}"
        print(error_message)
        return error_message

    print(transcript)
    return transcript
91
+
92
+
93
+ ## agent definition
94
class Agent:
    """Callable wrapper around a ``CodeAgent`` configured with this app's tools.

    The underlying agent uses an ``HfApiModel`` for "google/gemma-3-27b-it"
    through the "nebius" provider, authenticated with the ``NEBIUS_API_KEY``
    environment variable, and is given web search, webpage visiting, image
    download and audio transcription tools.
    """

    def __init__(self):
        client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
        self.agent = CodeAgent(
            model=client,
            tools=[
                DuckDuckGoSearchTool(max_results=5),
                VisitWebpageTool(max_output_length=20000),
                download_images,
                transcribe_audio,
            ],
            # Modules the agent-generated code is additionally allowed to import.
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
        )

    def __call__(self, message: str, images: Optional[list[Image.Image]] = None, files: Optional[dict] = None) -> str:
        """Run the agent on ``message`` and return its answer.

        Args:
            message: The task / user prompt handed to the agent.
            images: Optional list of PIL images forwarded to the agent.
            files: Optional dict of loaded file contents (the shape returned
                by ``load_file``), forwarded to the agent.

        Returns:
            Whatever ``CodeAgent.run`` returns for this task.
            NOTE(review): annotated as ``str``; confirm the agent always
            produces a string answer.
        """
        # Both extras are passed through `additional_args` under fixed keys.
        return self.agent.run(message, additional_args={"images": images, "files": files})
110
+
111
+ ## gradio functions
112
def respond(message, history):
    """Chat callback: forward the user's message (and first file) to the agent.

    Args:
        message: Mapping read via ``.get`` for a ``"text"`` entry and an
            optional ``"files"`` entry (a sequence of file paths).
        history: Conversation history supplied by the chat UI; unused here.

    Returns:
        The answer produced by the module-level ``agent``.
    """
    text = message.get("text", "")
    files = message.get("files")

    if not files:
        print("No files received.")
        return agent(text)

    print(f"files received: {files}")
    # Only the first attached file is loaded and handed to the agent.
    loaded = load_file(files[0])
    return agent(text, files=loaded)
125
+
126
def initialize_agent():
    """Build a fresh ``Agent``, log that it is ready, and return it."""
    instance = Agent()
    print("Agent initialized.")
    return instance
130
+
131
+
132
# Build the Gradio UI. `agent` is bound at module level here (no `global`
# statement is needed at module scope) so that respond() can look it up
# when a chat message arrives.
with gr.Blocks() as demo:
    agent = initialize_agent()
    gr.ChatInterface(
        fn=respond,
        type="messages",
        multimodal=True,
        title="MultiAgent_System_for_Screenplay_Creation_and_Editing",
        show_progress="full",
    )


# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ huggingface_hub==0.25.2
2
+ smolagents
3
+ openai