Or4cl3-2 commited on
Commit
039c729
Β·
verified Β·
1 Parent(s): d0a8e50

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +620 -0
app.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer, AutoModelForCausalLM,
7
+ TrainingArguments, Trainer,
8
+ DataCollatorForLanguageModeling,
9
+ pipeline
10
+ )
11
+ from datasets import Dataset
12
+ from huggingface_hub import HfApi, login
13
+ import spaces
14
+ from typing import Optional, Dict, Any, List, Tuple
15
+ import logging
16
+ import traceback
17
+ from datetime import datetime
18
+ import random
19
+ import re
20
+ from faker import Faker
21
+ import hashlib
22
+ import time
23
+ from collections import defaultdict
24
+ from functools import wraps
25
+
26
+ # Setup logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # ==================== RATE LIMITING ====================
34
+
35
+ class RateLimiter:
36
+ """Token bucket rate limiter"""
37
+
38
+ def __init__(self):
39
+ self.requests = defaultdict(list)
40
+ self.limits = {
41
+ 'synthetic_generation': {'calls': 10, 'period': 3600},
42
+ 'model_training': {'calls': 3, 'period': 3600},
43
+ 'model_inference': {'calls': 50, 'period': 3600},
44
+ }
45
+
46
+ def _get_user_id(self, request: gr.Request) -> str:
47
+ if request:
48
+ identifier = f"{request.client.host}_{request.headers.get('user-agent', '')}"
49
+ return hashlib.md5(identifier.encode()).hexdigest()
50
+ return "anonymous"
51
+
52
+ def _clean_old_requests(self, user_id: str, endpoint: str):
53
+ if user_id not in self.requests:
54
+ return
55
+ current_time = time.time()
56
+ period = self.limits[endpoint]['period']
57
+ self.requests[user_id] = [
58
+ req for req in self.requests[user_id]
59
+ if req['endpoint'] == endpoint and current_time - req['timestamp'] < period
60
+ ]
61
+
62
+ def check_rate_limit(self, user_id: str, endpoint: str) -> Tuple[bool, str]:
63
+ self._clean_old_requests(user_id, endpoint)
64
+ user_requests = [req for req in self.requests[user_id] if req['endpoint'] == endpoint]
65
+ limit = self.limits[endpoint]['calls']
66
+ period = self.limits[endpoint]['period']
67
+
68
+ if len(user_requests) >= limit:
69
+ time_until_reset = period - (time.time() - user_requests[0]['timestamp'])
70
+ minutes = int(time_until_reset / 60)
71
+ return False, f"⏱️ Rate limit exceeded! Please wait {minutes} minutes."
72
+
73
+ self.requests[user_id].append({'endpoint': endpoint, 'timestamp': time.time()})
74
+ remaining = limit - len(user_requests) - 1
75
+ return True, f"βœ… Request accepted ({remaining} remaining this hour)"
76
+
77
+ rate_limiter = RateLimiter()
78
+
79
+ def rate_limit(endpoint: str):
80
+ def decorator(func):
81
+ @wraps(func)
82
+ def wrapper(*args, **kwargs):
83
+ request = kwargs.get('request', None)
84
+ if request:
85
+ user_id = rate_limiter._get_user_id(request)
86
+ allowed, message = rate_limiter.check_rate_limit(user_id, endpoint)
87
+ if not allowed:
88
+ return f"🚫 {message}"
89
+ return func(*args, **kwargs)
90
+ return wrapper
91
+ return decorator
92
+
93
+ # ==================== AUTHENTICATION ====================
94
+
95
+ class AuthManager:
96
+ def __init__(self):
97
+ self.authenticated_tokens = {}
98
+ self.token_expiry = 86400
99
+
100
+ def validate_hf_token(self, token: str) -> Tuple[bool, str, Optional[str]]:
101
+ try:
102
+ if not token or not token.strip():
103
+ return False, "❌ Please provide a HuggingFace token", None
104
+
105
+ token_hash = hashlib.sha256(token.encode()).hexdigest()
106
+ if token_hash in self.authenticated_tokens:
107
+ cached = self.authenticated_tokens[token_hash]
108
+ if time.time() - cached['timestamp'] < self.token_expiry:
109
+ return True, f"βœ… Welcome back, {cached['username']}!", cached['username']
110
+
111
+ api = HfApi(token=token)
112
+ user_info = api.whoami()
113
+ username = user_info.get('name', 'Anonymous Architect')
114
+
115
+ self.authenticated_tokens[token_hash] = {
116
+ 'username': username,
117
+ 'timestamp': time.time()
118
+ }
119
+
120
+ return True, f"πŸŽ‰ Welcome, {username}!", username
121
+
122
+ except Exception as e:
123
+ return False, f"πŸ” Token validation failed: {str(e)}", None
124
+
125
+ auth_manager = AuthManager()
126
+
127
+ # ==================== ERROR HANDLING ====================
128
+
129
+ class ArchitechError(Exception):
130
+ pass
131
+
132
+ class DataGenerationError(ArchitechError):
133
+ pass
134
+
135
+ class ModelTrainingError(ArchitechError):
136
+ pass
137
+
138
+ class ModelInferenceError(ArchitechError):
139
+ pass
140
+
141
+ def handle_errors(error_type: str = "general"):
142
+ def decorator(func):
143
+ @wraps(func)
144
+ def wrapper(*args, **kwargs):
145
+ try:
146
+ return func(*args, **kwargs)
147
+ except torch.cuda.OutOfMemoryError:
148
+ return "πŸ”₯ **GPU Memory Overflow!** Try: smaller batch size, smaller model, or less data."
149
+ except PermissionError:
150
+ return "πŸ”’ **Permission Denied!** Check your HuggingFace token has WRITE access."
151
+ except ConnectionError:
152
+ return "🌐 **Connection Issue!** Can't reach HuggingFace. Check your network."
153
+ except ValueError as e:
154
+ return f"⚠️ **Invalid Input!** {str(e)}"
155
+ except (DataGenerationError, ModelTrainingError, ModelInferenceError) as e:
156
+ return f"πŸ”§ **Architech Error:** {str(e)}"
157
+ except Exception as e:
158
+ logger.error(f"Error in {func.__name__}: {traceback.format_exc()}")
159
+ return f"πŸ’₯ **Unexpected Error:** {str(e)}"
160
+ return wrapper
161
+ return decorator# ==================== SYNTHETIC DATA GENERATOR ====================
162
+
163
+ class SyntheticDataGenerator:
164
+ def __init__(self):
165
+ self.faker = Faker()
166
+ self.generation_templates = {
167
+ "conversational": [
168
+ "Human: {question}\nAssistant: {answer}",
169
+ "User: {question}\nBot: {answer}",
170
+ ],
171
+ "instruction": [
172
+ "### Instruction:\n{instruction}\n\n### Response:\n{response}",
173
+ ],
174
+ }
175
+
176
+ self.domain_knowledge = {
177
+ "technology": {
178
+ "topics": ["AI", "machine learning", "cloud computing"],
179
+ "concepts": ["algorithms", "APIs", "databases"],
180
+ "contexts": ["software development", "digital transformation"]
181
+ },
182
+ "healthcare": {
183
+ "topics": ["telemedicine", "diagnostics", "patient care"],
184
+ "concepts": ["treatments", "procedures"],
185
+ "contexts": ["clinical practice", "patient education"]
186
+ },
187
+ "finance": {
188
+ "topics": ["fintech", "investment", "risk management"],
189
+ "concepts": ["portfolios", "compliance"],
190
+ "contexts": ["financial advisory", "personal finance"]
191
+ },
192
+ "general": {
193
+ "topics": ["communication", "problem-solving"],
194
+ "concepts": ["strategies", "best practices"],
195
+ "contexts": ["daily life", "personal growth"]
196
+ }
197
+ }
198
+
199
+ def _generate_question(self, topic, concept, context):
200
+ templates = [
201
+ f"How does {concept} work in {context}?",
202
+ f"What are the benefits of {concept} for {topic}?",
203
+ f"Can you explain {concept}?",
204
+ f"What's the best approach to {concept}?"
205
+ ]
206
+ return random.choice(templates)
207
+
208
+ def _generate_answer(self, question, topic, concept):
209
+ templates = [
210
+ f"{concept} in {topic} works through strategic implementation. Key benefits include improved efficiency and better outcomes.",
211
+ f"Great question! {concept} is fundamental because it addresses core challenges. Best practices include planning and testing.",
212
+ f"When it comes to {concept}, consider scalability and performance. Success depends on proper implementation."
213
+ ]
214
+ return random.choice(templates)
215
+
216
+ def _generate_single_example(self, task_desc, domain_data, templates, complexity):
217
+ template = random.choice(templates)
218
+ topic = random.choice(domain_data["topics"])
219
+ concept = random.choice(domain_data["concepts"])
220
+ context = random.choice(domain_data["contexts"])
221
+
222
+ question = self._generate_question(topic, concept, context)
223
+ answer = self._generate_answer(question, topic, concept)
224
+
225
+ text = template.format(question=question, answer=answer)
226
+ return {"text": text}
227
+
228
+ @handle_errors("data_generation")
229
+ def generate_synthetic_dataset(
230
+ self,
231
+ task_description: str,
232
+ domain: str,
233
+ dataset_size: int = 100,
234
+ format_type: str = "conversational",
235
+ complexity: str = "medium",
236
+ progress=gr.Progress()
237
+ ) -> str:
238
+ if not task_description or len(task_description.strip()) < 10:
239
+ raise DataGenerationError("Task description too short! Need at least 10 characters.")
240
+
241
+ if dataset_size < 10 or dataset_size > 1000:
242
+ raise DataGenerationError("Dataset size must be between 10 and 1000.")
243
+
244
+ progress(0.1, f"🎯 Generating {dataset_size} examples...")
245
+
246
+ domain_data = self.domain_knowledge.get(domain, self.domain_knowledge["general"])
247
+ templates = self.generation_templates.get(format_type, self.generation_templates["conversational"])
248
+
249
+ synthetic_data = []
250
+ for i in range(dataset_size):
251
+ if i % 20 == 0:
252
+ progress(0.1 + (0.7 * i / dataset_size), f"πŸ“ Creating {i+1}/{dataset_size}...")
253
+
254
+ example = self._generate_single_example(task_description, domain_data, templates, complexity)
255
+ synthetic_data.append(example)
256
+
257
+ os.makedirs("./synthetic_datasets", exist_ok=True)
258
+ dataset_filename = f"synthetic_{domain}_{format_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
259
+ dataset_path = os.path.join("./synthetic_datasets", dataset_filename)
260
+
261
+ with open(dataset_path, 'w') as f:
262
+ json.dump(synthetic_data, f, indent=2)
263
+
264
+ preview = "\n\n---\n\n".join([ex["text"] for ex in synthetic_data[:3]])
265
+
266
+ return f"""🎊 **SYNTHETIC DATASET GENERATED!**
267
+
268
+ **Dataset Details:**
269
+ - πŸ“Š Size: {len(synthetic_data)} examples
270
+ - 🎯 Domain: {domain.title()}
271
+ - πŸ“ Format: {format_type.title()}
272
+ - πŸ’Ύ Saved as: `{dataset_filename}`
273
+
274
+ **Preview (First 3 Examples):**
275
+
276
+ {preview}
277
+
278
+ **Next Steps:** Use this in the 'Train Model' or 'Test Model' tabs!"""# ==================== MODEL INFERENCE ====================
279
+
280
+ class ModelInference:
281
+ def __init__(self):
282
+ self.loaded_models = {}
283
+
284
+ @handle_errors("inference")
285
+ def load_model(self, model_name: str, hf_token: str, progress=gr.Progress()) -> str:
286
+ progress(0.1, "πŸ” Locating your model...")
287
+
288
+ is_valid, message, username = auth_manager.validate_hf_token(hf_token)
289
+ if not is_valid:
290
+ raise ModelInferenceError(message)
291
+
292
+ full_model_name = f"{username}/{model_name}" if "/" not in model_name else model_name
293
+
294
+ progress(0.3, "πŸ“₯ Downloading model...")
295
+
296
+ try:
297
+ tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
298
+ model = AutoModelForCausalLM.from_pretrained(
299
+ full_model_name,
300
+ token=hf_token,
301
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
302
+ device_map="auto" if torch.cuda.is_available() else None
303
+ )
304
+
305
+ self.loaded_models[model_name] = {
306
+ 'model': model,
307
+ 'tokenizer': tokenizer,
308
+ 'pipeline': pipeline('text-generation', model=model, tokenizer=tokenizer)
309
+ }
310
+
311
+ progress(1.0, "βœ… Model loaded!")
312
+ return f"βœ… **Model Loaded Successfully!**\n\nModel: `{full_model_name}`\n\nReady for inference!"
313
+
314
+ except Exception as e:
315
+ raise ModelInferenceError(f"Failed to load model: {str(e)}")
316
+
317
+ @handle_errors("inference")
318
+ def generate_text(
319
+ self,
320
+ model_name: str,
321
+ prompt: str,
322
+ max_length: int = 100,
323
+ temperature: float = 0.7,
324
+ top_p: float = 0.9
325
+ ) -> str:
326
+ if model_name not in self.loaded_models:
327
+ raise ModelInferenceError("Model not loaded! Please load the model first.")
328
+
329
+ if not prompt or len(prompt.strip()) < 3:
330
+ raise ModelInferenceError("Prompt too short! Please provide at least 3 characters.")
331
+
332
+ pipe = self.loaded_models[model_name]['pipeline']
333
+
334
+ result = pipe(
335
+ prompt,
336
+ max_length=max_length,
337
+ temperature=temperature,
338
+ top_p=top_p,
339
+ do_sample=True,
340
+ num_return_sequences=1
341
+ )
342
+
343
+ generated_text = result[0]['generated_text']
344
+
345
+ return f"""**🎯 Generated Response:**
346
+
347
+ {generated_text}
348
+
349
+ ---
350
+ *Model: {model_name} | Length: {len(generated_text)} chars*"""
351
+
352
+ model_inference = ModelInference()# ==================== ARCHITECH AGENT ====================
353
+
354
+ class ArchitechAgent:
355
+ def __init__(self):
356
+ self.hf_api = HfApi()
357
+ self.synthetic_generator = SyntheticDataGenerator()
358
+ self.personality_responses = [
359
+ "🎯 Let's cook up some AI magic!",
360
+ "πŸš€ Time to turn your vision into reality!",
361
+ "🧠 Let's architect some brilliance!",
362
+ ]
363
+
364
+ def get_personality_response(self) -> str:
365
+ return random.choice(self.personality_responses)
366
+
367
+ @rate_limit('synthetic_generation')
368
+ @handle_errors("data_generation")
369
+ def generate_synthetic_dataset_wrapper(self, *args, **kwargs):
370
+ return self.synthetic_generator.generate_synthetic_dataset(*args, **kwargs)
371
+
372
+ @spaces.GPU
373
+ @rate_limit('model_training')
374
+ @handle_errors("training")
375
+ def train_custom_model(
376
+ self,
377
+ task_description: str,
378
+ training_data: str,
379
+ model_name: str,
380
+ hf_token: str,
381
+ base_model: str = "distilgpt2",
382
+ use_synthetic_data: bool = True,
383
+ synthetic_domain: str = "general",
384
+ synthetic_size: int = 100,
385
+ learning_rate: float = 2e-4,
386
+ num_epochs: int = 3,
387
+ batch_size: int = 2,
388
+ progress=gr.Progress()
389
+ ) -> str:
390
+
391
+ is_valid, message, username = auth_manager.validate_hf_token(hf_token)
392
+ if not is_valid:
393
+ raise ModelTrainingError(message)
394
+
395
+ progress(0.1, "🧠 Loading base model...")
396
+
397
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
398
+ if tokenizer.pad_token is None:
399
+ tokenizer.pad_token = tokenizer.eos_token
400
+
401
+ model = AutoModelForCausalLM.from_pretrained(
402
+ base_model,
403
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
404
+ device_map="auto" if torch.cuda.is_available() else None
405
+ )
406
+
407
+ if use_synthetic_data:
408
+ progress(0.2, "🎨 Generating synthetic data...")
409
+ result = self.synthetic_generator.generate_synthetic_dataset(
410
+ task_description, synthetic_domain, synthetic_size, "conversational", "medium", progress
411
+ )
412
+
413
+ dataset_files = [f for f in os.listdir("./synthetic_datasets") if f.endswith('.json')]
414
+ if not dataset_files:
415
+ raise ModelTrainingError("No synthetic dataset found!")
416
+
417
+ latest_dataset = max(dataset_files, key=lambda x: os.path.getctime(os.path.join("./synthetic_datasets", x)))
418
+ with open(os.path.join("./synthetic_datasets", latest_dataset), 'r') as f:
419
+ synthetic_data = json.load(f)
420
+ texts = [item["text"] for item in synthetic_data]
421
+ else:
422
+ texts = [t.strip() for t in training_data.split("\n\n") if t.strip()]
423
+
424
+ if not texts:
425
+ raise ModelTrainingError("No training data available!")
426
+
427
+ progress(0.3, f"✨ Tokenizing {len(texts)} examples...")
428
+
429
+ dataset = Dataset.from_dict({"text": texts})
430
+
431
+ def tokenize_function(examples):
432
+ return tokenizer(examples["text"], truncation=True, padding=True, max_length=256)
433
+
434
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
435
+
436
+ progress(0.4, "βš™οΈ Configuring training...")
437
+
438
+ training_args = TrainingArguments(
439
+ output_dir=f"./results_{model_name}",
440
+ num_train_epochs=num_epochs,
441
+ per_device_train_batch_size=batch_size,
442
+ gradient_accumulation_steps=4,
443
+ learning_rate=learning_rate,
444
+ logging_steps=50,
445
+ save_steps=500,
446
+ save_total_limit=2,
447
+ fp16=torch.cuda.is_available(),
448
+ report_to="none"
449
+ )
450
+
451
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
452
+
453
+ trainer = Trainer(
454
+ model=model,
455
+ args=training_args,
456
+ train_dataset=tokenized_dataset,
457
+ data_collator=data_collator,
458
+ )
459
+
460
+ progress(0.6, "πŸ’ͺ Training in progress...")
461
+ trainer.train()
462
+
463
+ progress(0.8, "πŸ’Ύ Saving model...")
464
+ output_dir = f"./trained_{model_name}"
465
+ trainer.save_model(output_dir)
466
+ tokenizer.save_pretrained(output_dir)
467
+
468
+ progress(0.9, "πŸ“€ Pushing to HuggingFace...")
469
+
470
+ try:
471
+ login(token=hf_token)
472
+ model.push_to_hub(model_name, token=hf_token)
473
+ tokenizer.push_to_hub(model_name, token=hf_token)
474
+ hub_url = f"https://huggingface.co/{username}/{model_name}"
475
+
476
+ return f"""πŸŽ‰ **TRAINING COMPLETE!**
477
+
478
+ βœ… Training successful
479
+ πŸ’Ύ Model saved locally
480
+ πŸ“€ Pushed to Hub
481
+ πŸ”— **Your model:** {hub_url}
482
+
483
+ **Stats:**
484
+ - Examples: {len(texts)}
485
+ - Epochs: {num_epochs}
486
+ - Learning rate: {learning_rate}
487
+
488
+ **Test it in the 'Test Model' tab!**"""
489
+
490
+ except Exception as e:
491
+ return f"βœ… Training done but upload failed: {str(e)}\nModel saved at: {output_dir}"# ==================== GRADIO INTERFACE ====================
492
+
493
+ def create_gradio_interface():
494
+ agent = ArchitechAgent()
495
+
496
+ with gr.Blocks(title="πŸ—οΈ Architech", theme=gr.themes.Soft()) as demo:
497
+ gr.Markdown("""
498
+ # πŸ—οΈ **Architech - Your AI Model Architect**
499
+
500
+ *Describe what you want, and I'll build it for you!*
501
+ """)
502
+
503
+ with gr.Tabs():
504
+ # Generate Dataset
505
+ with gr.Tab("πŸ“Š Generate Dataset"):
506
+ with gr.Row():
507
+ with gr.Column():
508
+ task_desc = gr.Textbox(label="Task Description", lines=3,
509
+ placeholder="E.g., 'Customer support chatbot for tech products'")
510
+ domain = gr.Dropdown(
511
+ choices=["technology", "healthcare", "finance", "general"],
512
+ label="Domain", value="general")
513
+ dataset_size = gr.Slider(50, 500, 100, step=50, label="Dataset Size")
514
+ format_type = gr.Dropdown(
515
+ choices=["conversational", "instruction"],
516
+ label="Format", value="conversational")
517
+ gen_btn = gr.Button("🎨 Generate Dataset", variant="primary")
518
+ with gr.Column():
519
+ gen_output = gr.Markdown()
520
+
521
+ gen_btn.click(
522
+ fn=agent.generate_synthetic_dataset_wrapper,
523
+ inputs=[task_desc, domain, dataset_size, format_type, gr.State("medium")],
524
+ outputs=gen_output
525
+ )
526
+
527
+ # Train Model
528
+ with gr.Tab("πŸš€ Train Model"):
529
+ with gr.Row():
530
+ with gr.Column():
531
+ task_desc_train = gr.Textbox(label="Task Description", lines=2)
532
+ model_name = gr.Textbox(label="Model Name", placeholder="my-awesome-model")
533
+ hf_token = gr.Textbox(label="HuggingFace Token", type="password")
534
+ use_synthetic = gr.Checkbox(label="Use Synthetic Data", value=True)
535
+
536
+ with gr.Accordion("βš™οΈ Advanced", open=False):
537
+ base_model = gr.Dropdown(
538
+ choices=["distilgpt2", "gpt2", "microsoft/DialoGPT-small"],
539
+ label="Base Model", value="distilgpt2")
540
+ learning_rate = gr.Slider(1e-5, 5e-4, 2e-4, label="Learning Rate")
541
+ num_epochs = gr.Slider(1, 5, 3, step=1, label="Epochs")
542
+ batch_size = gr.Slider(1, 4, 2, step=1, label="Batch Size")
543
+
544
+ train_btn = gr.Button("🎯 Train Model", variant="primary")
545
+
546
+ with gr.Column():
547
+ train_output = gr.Markdown()
548
+
549
+ train_btn.click(
550
+ fn=agent.train_custom_model,
551
+ inputs=[task_desc_train, gr.State(""), model_name, hf_token,
552
+ base_model, use_synthetic, gr.State("general"),
553
+ gr.State(100), learning_rate, num_epochs, batch_size],
554
+ outputs=train_output
555
+ )
556
+
557
+ # Test Model
558
+ with gr.Tab("πŸ§ͺ Test Model"):
559
+ with gr.Row():
560
+ with gr.Column():
561
+ test_model_name = gr.Textbox(label="Model Name",
562
+ placeholder="username/model-name")
563
+ test_token = gr.Textbox(label="HuggingFace Token", type="password")
564
+ load_btn = gr.Button("πŸ“₯ Load Model")
565
+
566
+ gr.Markdown("---")
567
+
568
+ test_prompt = gr.Textbox(label="Test Prompt", lines=3,
569
+ placeholder="Enter your prompt here...")
570
+ max_length = gr.Slider(50, 200, 100, label="Max Length")
571
+ temperature = gr.Slider(0.1, 1.0, 0.7, label="Temperature")
572
+
573
+ test_btn = gr.Button("🎯 Generate", variant="primary")
574
+
575
+ with gr.Column():
576
+ load_output = gr.Markdown()
577
+ test_output = gr.Markdown()
578
+
579
+ load_btn.click(
580
+ fn=model_inference.load_model,
581
+ inputs=[test_model_name, test_token],
582
+ outputs=load_output
583
+ )
584
+
585
+ test_btn.click(
586
+ fn=model_inference.generate_text,
587
+ inputs=[test_model_name, test_prompt, max_length, temperature, gr.State(0.9)],
588
+ outputs=test_output
589
+ )
590
+
591
+ # About
592
+ with gr.Tab("ℹ️ About"):
593
+ gr.Markdown("""
594
+ ## πŸ—οΈ Architech - Your AI Model Architect
595
+
596
+ ### Features:
597
+ - 🎨 **Generate Synthetic Data**: No training data? No problem!
598
+ - πŸš€ **Train Custom Models**: Fine-tune models for your specific needs
599
+ - πŸ§ͺ **Test Your Models**: Load and test your models instantly
600
+ - ⚑ **Rate Limited**: Fair usage for all users
601
+ - πŸ”’ **Secure**: Token-based authentication
602
+
603
+ ### How to Use:
604
+ 1. Generate synthetic training data for your task
605
+ 2. Train a custom model with your data
606
+ 3. Test and deploy your model!
607
+
608
+ ### Rate Limits:
609
+ - Dataset Generation: 10 per hour
610
+ - Model Training: 3 per hour
611
+ - Model Inference: 50 per hour
612
+
613
+ *Built with ❀️ using Gradio, Transformers, and HuggingFace*
614
+ """)
615
+
616
+ return demo
617
+
618
+ if __name__ == "__main__":
619
+ demo = create_gradio_interface()
620
+ demo.launch()