igardner commited on
Commit
7955305
Β·
1 Parent(s): c85251a

Changes to be committed:

Browse files
Files changed (1) hide show
  1. app.py +371 -88
app.py CHANGED
@@ -1,91 +1,374 @@
1
  import gradio as gr
2
- import torch
3
- from PIL import Image
4
- import numpy as np
5
- from transformers import pipeline
6
- from diffusers import AutoPipelineForInpainting
7
- import requests
8
- from io import BytesIO
9
-
10
- # ----------------------------
11
- # 1. Load models (lazy init for HF Spaces)
12
- # ----------------------------
13
- detector = None
14
- inpainter = None
15
-
16
- def load_models():
17
- global detector, inpainter
18
- if detector is None:
19
- print("Loading Grounding DINO...")
20
- detector = pipeline(
21
- "object-detection",
22
- model="facebook/grounding-dino-base",
23
- device="cpu"
24
- )
25
- if inpainter is None:
26
- print("Loading Stable Diffusion 3 Inpainting...")
27
- inpainter = AutoPipelineForInpainting.from_pretrained(
28
- "stabilityai/stable-diffusion-3-medium",
29
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
30
- safety_checker=None,
31
- )
32
- if torch.cuda.is_available():
33
- inpainter = inpainter.to("cuda")
34
-
35
- # ----------------------------
36
- # 2. Core logic
37
- # ----------------------------
38
- def remove_horse(image: Image.Image) -> Image.Image:
39
- load_models()
40
-
41
- # Detect horses
42
- results = detector(image, candidate_labels="horse")
43
- horse_detections = [r for r in results if r['label'] == 'horse' and r['score'] > 0.3]
44
-
45
- if not horse_detections:
46
- return image # No horse? Return original.
47
-
48
- # Create mask (white on black)
49
- mask = np.zeros(image.size[::-1], dtype=np.uint8)
50
- for det in horse_detections:
51
- box = det['box']
52
- x0, y0 = int(box['xmin']), int(box['ymin'])
53
- x1, y1 = int(box['xmax']), int(box['ymax'])
54
- mask[y0:y1, x0:x1] = 255
55
-
56
- mask_img = Image.fromarray(mask).convert("L")
57
-
58
- # Inpaint
59
- prompt = "a natural scene with no horse, seamless background, photorealistic"
60
- edited = inpainter(
61
- prompt=prompt,
62
- image=image,
63
- mask_image=mask_img,
64
- num_inference_steps=28,
65
- strength=0.99,
66
- guidance_scale=7.0
67
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- return edited
70
-
71
- # ----------------------------
72
- # 3. Gradio UI
73
- # ----------------------------
74
- with gr.Blocks(title="NO HORSE") as demo:
75
- gr.Markdown("# 🐴 NO HORSE")
76
- gr.Markdown("Upload an image. Horses will be **erased from reality**.")
77
- with gr.Row():
78
- input_img = gr.Image(type="pil", label="Input (with horse)")
79
- output_img = gr.Image(type="pil", label="Output (no horse)")
80
- btn = gr.Button("PURGE HORSE")
81
- btn.click(fn=remove_horse, inputs=input_img, outputs=output_img)
82
- gr.Examples(
83
- examples=["examples/horse1.jpg", "examples/horse2.jpg"],
84
- inputs=input_img
85
- )
86
-
87
- # ----------------------------
88
- # Launch
89
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  if __name__ == "__main__":
91
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ import logging
3
+ from contextlib import contextmanager
4
+ from typing import Optional
5
+ import re
6
+ import json
7
+
8
+ # === Internal Imports (your fixed backend) ===
9
+ from smolagents import Tool
10
+ from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey
11
+ from sqlalchemy.ext.declarative import declarative_base
12
+ from sqlalchemy.orm import sessionmaker, relationship
13
+ import discum
14
+ from datetime import datetime
15
+
16
+ # === Configuration ===
17
+ DB_PATH = "sqlite:///discord_bots.db"
18
+ Base = declarative_base()
19
+ engine = create_engine(DB_PATH, echo=False)
20
+ SessionLocal = sessionmaker(bind=engine)
21
+
22
+ # === Logging ===
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # === Models (same as before) ===
27
+ class Personality(Base):
28
+ __tablename__ = 'personalities'
29
+ id = Column(Integer, primary_key=True)
30
+ name = Column(String, nullable=False)
31
+ base_prompt = Column(Text)
32
+ created_at = Column(DateTime, default=datetime.utcnow)
33
+
34
+ class ChannelConfig(Base):
35
+ __tablename__ = 'channel_configs'
36
+ id = Column(Integer, primary_key=True)
37
+ channel_id = Column(String, nullable=False)
38
+ personality_id = Column(Integer, ForeignKey('personalities.id'))
39
+ context_analysis = Column(Text)
40
+ adapted_prompt = Column(Text)
41
+ created_at = Column(DateTime, default=datetime.utcnow)
42
+ personality = relationship("Personality")
43
+
44
+ Base.metadata.create_all(engine)
45
+
46
+ # === Global State ===
47
+ _discum_clients = {} # token_id (int) -> discum.Client
48
+
49
+ # === Utilities ===
50
+ @contextmanager
51
+ def get_db_session():
52
+ session = SessionLocal()
53
+ try:
54
+ yield session
55
+ session.commit()
56
+ except Exception as e:
57
+ session.rollback()
58
+ logger.exception("Database error")
59
+ raise
60
+ finally:
61
+ session.close()
62
+
63
+ def validate_discord_id(discord_id: str) -> bool:
64
+ return bool(re.fullmatch(r'\d+', discord_id))
65
+
66
+ # === Core Tool Functions (refactored for Gradio) ===
67
+ def create_personality(name: str, base_prompt: str):
68
+ if not name.strip() or not base_prompt.strip():
69
+ return gr.update(value="❌ Error: Name and base prompt cannot be empty.", visible=True), None
70
+ try:
71
+ with get_db_session() as session:
72
+ personality = Personality(name=name.strip(), base_prompt=base_prompt.strip())
73
+ session.add(personality)
74
+ session.flush()
75
+ pid = personality.id
76
+ return gr.update(value=f"βœ… Personality '{name}' created with ID: {pid}", visible=True), pid
77
+ except Exception as e:
78
+ logger.error(f"Create personality error: {e}")
79
+ return gr.update(value="❌ Failed to create personality.", visible=True), None
80
+
81
+ def configure_channel(channel_id: str, personality_id: str, context_analysis: str):
82
+ if not channel_id or not personality_id or not context_analysis.strip():
83
+ return gr.update(value="❌ All fields are required.", visible=True)
84
+ if not validate_discord_id(channel_id):
85
+ return gr.update(value="❌ Invalid Channel ID (must be numeric).", visible=True)
86
+ try:
87
+ pid = int(personality_id)
88
+ if pid <= 0:
89
+ raise ValueError("Invalid ID")
90
+ with get_db_session() as session:
91
+ personality = session.query(Personality).filter(Personality.id == pid).first()
92
+ if not personality:
93
+ return gr.update(value=f"❌ Personality ID {pid} not found.", visible=True)
94
+ config = ChannelConfig(
95
+ channel_id=channel_id,
96
+ personality_id=pid,
97
+ context_analysis=context_analysis.strip()
98
+ )
99
+ session.add(config)
100
+ return gr.update(value=f"βœ… Channel {channel_id} configured with personality {pid}", visible=True)
101
+ except ValueError:
102
+ return gr.update(value="❌ Personality ID must be a positive integer.", visible=True)
103
+ except Exception as e:
104
+ logger.error(f"Configure channel error: {e}")
105
+ return gr.update(value="❌ Failed to configure channel.", visible=True)
106
+
107
+ def generate_channel_prompt(channel_id: str, personality_id: str, llm_analysis: str):
108
+ if not channel_id or not personality_id or not llm_analysis.strip():
109
+ return gr.update(value="❌ All fields are required.", visible=True), ""
110
+ if not validate_discord_id(channel_id):
111
+ return gr.update(value="❌ Invalid Channel ID.", visible=True), ""
112
+ try:
113
+ pid = int(personality_id)
114
+ if pid <= 0:
115
+ raise ValueError
116
+ with get_db_session() as session:
117
+ personality = session.query(Personality).filter(Personality.id == pid).first()
118
+ if not personality:
119
+ return gr.update(value="❌ Personality not found.", visible=True), ""
120
+ prompt = f"{personality.base_prompt}\n\nChannel Context:\n{llm_analysis.strip()}"
121
+ config = session.query(ChannelConfig).filter(
122
+ ChannelConfig.channel_id == channel_id,
123
+ ChannelConfig.personality_id == pid
124
+ ).first()
125
+ if config:
126
+ config.adapted_prompt = prompt
127
+ else:
128
+ config = ChannelConfig(
129
+ channel_id=channel_id,
130
+ personality_id=pid,
131
+ adapted_prompt=prompt
132
+ )
133
+ session.add(config)
134
+ return gr.update(value="βœ… Prompt generated and saved.", visible=True), prompt
135
+ except ValueError:
136
+ return gr.update(value="❌ Personality ID must be a positive integer.", visible=True), ""
137
+ except Exception as e:
138
+ logger.error(f"Generate prompt error: {e}")
139
+ return gr.update(value="❌ Failed to generate prompt.", visible=True), ""
140
+
141
+ def get_channel_prompt(channel_id: str, personality_id: str):
142
+ if not channel_id or not personality_id:
143
+ return gr.update(value="❌ Both fields required.", visible=True), ""
144
+ if not validate_discord_id(channel_id):
145
+ return gr.update(value="❌ Invalid Channel ID.", visible=True), ""
146
+ try:
147
+ pid = int(personality_id)
148
+ if pid <= 0:
149
+ raise ValueError
150
+ with get_db_session() as session:
151
+ config = session.query(ChannelConfig).filter(
152
+ ChannelConfig.channel_id == channel_id,
153
+ ChannelConfig.personality_id == pid
154
+ ).first()
155
+ if config and config.adapted_prompt:
156
+ return gr.update(value="βœ… Prompt retrieved.", visible=True), config.adapted_prompt
157
+ return gr.update(value="ℹ️ No adapted prompt found.", visible=True), ""
158
+ except ValueError:
159
+ return gr.update(value="❌ Personality ID must be a positive integer.", visible=True), ""
160
+ except Exception as e:
161
+ logger.error(f"Get prompt error: {e}")
162
+ return gr.update(value="❌ Failed to retrieve prompt.", visible=True), ""
163
+
164
+ def initialize_bot(token_id: str, token: str):
165
+ if not token_id or not token.strip():
166
+ return gr.update(value="❌ Token ID and token are required.", visible=True)
167
+ try:
168
+ tid = int(token_id)
169
+ if tid <= 0:
170
+ raise ValueError("Invalid token ID")
171
+ if len(token) < 50:
172
+ return gr.update(value="❌ Token appears invalid (too short).", visible=True)
173
+ if tid in _discum_clients:
174
+ _discum_clients[tid].close()
175
+ client = discum.Client(token=token, log=False)
176
+ _discum_clients[tid] = client
177
+ return gr.update(value=f"βœ… Bot initialized with Token ID: {tid}", visible=True)
178
+ except ValueError:
179
+ return gr.update(value="❌ Token ID must be a positive integer.", visible=True)
180
+ except Exception as e:
181
+ logger.error(f"Bot init error: {e}")
182
+ return gr.update(value="❌ Failed to initialize bot. Check token.", visible=True)
183
+
184
+ def send_message(token_id: str, channel_id: str, message: str):
185
+ if not token_id or not channel_id or not message.strip():
186
+ return gr.update(value="❌ All fields required.", visible=True)
187
+ if not validate_discord_id(channel_id):
188
+ return gr.update(value="❌ Invalid Channel ID.", visible=True)
189
+ try:
190
+ tid = int(token_id)
191
+ if tid not in _discum_clients:
192
+ return gr.update(value="❌ Bot not initialized. Use 'Initialize Bot' first.", visible=True)
193
+ client = _discum_clients[tid]
194
+ client.sendMessage(channel_id, message.strip())
195
+ return gr.update(value="βœ… Message sent successfully!", visible=True)
196
+ except ValueError:
197
+ return gr.update(value="❌ Token ID must be an integer.", visible=True)
198
+ except Exception as e:
199
+ logger.error(f"Send message error: {e}")
200
+ return gr.update(value="❌ Failed to send message. Check bot permissions.", visible=True)
201
+
202
+ def get_channel_messages(token_id: str, channel_id: str, limit: int):
203
+ if not token_id or not channel_id:
204
+ return gr.update(value="❌ Token ID and Channel ID required.", visible=True), ""
205
+ if not validate_discord_id(channel_id):
206
+ return gr.update(value="❌ Invalid Channel ID.", visible=True), ""
207
+ try:
208
+ tid = int(token_id)
209
+ if tid not in _discum_clients:
210
+ return gr.update(value="❌ Bot not initialized.", visible=True), ""
211
+ limit = max(1, min(limit, 100))
212
+ client = _discum_clients[tid]
213
+ resp = client.getMessages(channel_id, num=limit)
214
+ if resp.status_code != 200:
215
+ return gr.update(value=f"❌ Discord API error: {resp.status_code}", visible=True), ""
216
+ messages = resp.json()
217
+ formatted = json.dumps(messages, indent=2)
218
+ return gr.update(value="βœ… Messages retrieved.", visible=True), formatted
219
+ except ValueError:
220
+ return gr.update(value="❌ Token ID must be an integer.", visible=True), ""
221
+ except Exception as e:
222
+ logger.error(f"Get messages error: {e}")
223
+ return gr.update(value="❌ Failed to fetch messages.", visible=True), ""
224
+
225
+ # === Gradio Interface ===
226
+ with gr.Blocks(title="Discord Personality Bot Manager") as demo:
227
+ gr.Markdown("# πŸ€– Discord Personality Bot Manager")
228
+ gr.Markdown("Manage AI personalities, configure channels, and interact with Discord via a secure web interface.")
229
 
230
+ with gr.Tabs():
231
+ # === Documentation Tab ===
232
+ with gr.Tab("πŸ“˜ Documentation"):
233
+ gr.Markdown("""
234
+ ## How to Use This Tool
235
+
236
+ ### Step 1: Initialize Your Bot
237
+ - Go to **Bot Management** β†’ **Initialize Bot**
238
+ - Provide a **Token ID** (any positive integer, e.g., `1`) and your **Discord token**
239
+ - ⚠️ **Never share your token!** It grants full account access.
240
+
241
+ ### Step 2: Create a Personality
242
+ - Go to **Personality Management** β†’ **Create Personality**
243
+ - Give it a name (e.g., "Sarcastic Helper") and a base prompt (e.g., "You are a witty assistant...")
244
+
245
+ ### Step 3: Configure a Channel
246
+ - Get your Discord **Channel ID** (enable Developer Mode in Discord β†’ right-click channel β†’ Copy ID)
247
+ - Use **Configure Channel** to link a personality to a channel with context analysis
248
+ - Or use **Generate Channel Prompt** to auto-create an adapted prompt
249
+
250
+ ### Step 4: Interact
251
+ - Use **Get Channel Prompt** to see the final prompt
252
+ - Use **Send Message** to post as your bot
253
+ - Use **Get Messages** to fetch recent channel history
254
+
255
+ ---
256
+ **Security Notes**:
257
+ - Tokens are stored in memory only (not saved to disk)
258
+ - Always close the browser tab when done
259
+ - Re-initialize if the bot disconnects
260
+ """)
261
+
262
+ # === Personality Management ===
263
+ with gr.Tab("🧠 Personality Management"):
264
+ with gr.Row():
265
+ with gr.Column():
266
+ gr.Markdown("### Create New Personality")
267
+ name_input = gr.Textbox(label="Personality Name", placeholder="e.g., Helpful Assistant")
268
+ prompt_input = gr.Textbox(label="Base Prompt", lines=3, placeholder="Describe the personality...")
269
+ create_btn = gr.Button("Create Personality")
270
+ create_status = gr.Textbox(label="Status", interactive=False, visible=False)
271
+ created_id = gr.Number(label="Created Personality ID", visible=False)
272
+
273
+ with gr.Column():
274
+ gr.Markdown("### Generate Channel-Specific Prompt")
275
+ gen_ch_id = gr.Textbox(label="Channel ID")
276
+ gen_pid = gr.Textbox(label="Personality ID")
277
+ llm_analysis = gr.Textbox(label="LLM Context Analysis", lines=3, placeholder="e.g., This channel discusses Python programming...")
278
+ gen_btn = gr.Button("Generate & Save Prompt")
279
+ gen_status = gr.Textbox(label="Status", interactive=False, visible=False)
280
+ gen_output = gr.Textbox(label="Generated Prompt", lines=5, interactive=False)
281
+
282
+ create_btn.click(
283
+ create_personality,
284
+ inputs=[name_input, prompt_input],
285
+ outputs=[create_status, created_id]
286
+ )
287
+ gen_btn.click(
288
+ generate_channel_prompt,
289
+ inputs=[gen_ch_id, gen_pid, llm_analysis],
290
+ outputs=[gen_status, gen_output]
291
+ )
292
+
293
+ # === Channel Configuration ===
294
+ with gr.Tab("βš™οΈ Channel Configuration"):
295
+ with gr.Row():
296
+ with gr.Column():
297
+ gr.Markdown("### Configure Channel with Personality")
298
+ conf_ch_id = gr.Textbox(label="Channel ID")
299
+ conf_pid = gr.Textbox(label="Personality ID")
300
+ context_analysis = gr.Textbox(label="Context Analysis", lines=3, placeholder="Describe the channel's purpose...")
301
+ conf_btn = gr.Button("Configure Channel")
302
+ conf_status = gr.Textbox(label="Status", interactive=False, visible=False)
303
+
304
+ with gr.Column():
305
+ gr.Markdown("### Retrieve Channel Prompt")
306
+ get_ch_id = gr.Textbox(label="Channel ID")
307
+ get_pid = gr.Textbox(label="Personality ID")
308
+ get_btn = gr.Button("Get Prompt")
309
+ get_status = gr.Textbox(label="Status", interactive=False, visible=False)
310
+ get_output = gr.Textbox(label="Adapted Prompt", lines=5, interactive=False)
311
+
312
+ conf_btn.click(
313
+ configure_channel,
314
+ inputs=[conf_ch_id, conf_pid, context_analysis],
315
+ outputs=conf_status
316
+ )
317
+ get_btn.click(
318
+ get_channel_prompt,
319
+ inputs=[get_ch_id, get_pid],
320
+ outputs=[get_status, get_output]
321
+ )
322
+
323
+ # === Bot Management ===
324
+ with gr.Tab("πŸ€– Bot Management"):
325
+ with gr.Row():
326
+ with gr.Column():
327
+ gr.Markdown("### Initialize Discord Bot")
328
+ token_id_input = gr.Textbox(label="Token ID (e.g., 1)", placeholder="Positive integer")
329
+ token_input = gr.Textbox(label="Discord Token", type="password", placeholder="Paste your token here")
330
+ init_btn = gr.Button("Initialize Bot")
331
+ init_status = gr.Textbox(label="Status", interactive=False, visible=False)
332
+
333
+ with gr.Column():
334
+ gr.Markdown("### Send Message to Channel")
335
+ send_token_id = gr.Textbox(label="Token ID")
336
+ send_ch_id = gr.Textbox(label="Channel ID")
337
+ message_input = gr.Textbox(label="Message", lines=3)
338
+ send_btn = gr.Button("Send Message")
339
+ send_status = gr.Textbox(label="Status", interactive=False, visible=False)
340
+
341
+ with gr.Row():
342
+ with gr.Column():
343
+ gr.Markdown("### Get Channel Messages")
344
+ msg_token_id = gr.Textbox(label="Token ID")
345
+ msg_ch_id = gr.Textbox(label="Channel ID")
346
+ limit_slider = gr.Slider(1, 100, value=10, label="Number of Messages")
347
+ get_msgs_btn = gr.Button("Fetch Messages")
348
+ msgs_status = gr.Textbox(label="Status", interactive=False, visible=False)
349
+ messages_output = gr.Textbox(label="Messages (JSON)", lines=10, interactive=False)
350
+
351
+ init_btn.click(
352
+ initialize_bot,
353
+ inputs=[token_id_input, token_input],
354
+ outputs=init_status
355
+ )
356
+ send_btn.click(
357
+ send_message,
358
+ inputs=[send_token_id, send_ch_id, message_input],
359
+ outputs=send_status
360
+ )
361
+ get_msgs_btn.click(
362
+ get_channel_messages,
363
+ inputs=[msg_token_id, msg_ch_id, limit_slider],
364
+ outputs=[msgs_status, messages_output]
365
+ )
366
+
367
+ # Launch the app
368
  if __name__ == "__main__":
369
+ demo.launch(
370
+ server_name="0.0.0.0", # Allow external access
371
+ server_port=7860,
372
+ show_error=True,
373
+ share=False # Set to True for public link (use cautiously!)
374
+ )