# NOTE: The lines below originated from the Hugging Face Space page header
# ("Spaces: Running") — extraction residue, kept only as a comment so the
# file remains valid Python.
import glob
import json
import logging
import os
import random
import re
import time
import zipfile
from datetime import datetime
from io import BytesIO
from typing import Dict, List

import nltk
import pandas as pd
import requests
from nltk.sentiment import SentimentIntensityAnalyzer
from sklearn.model_selection import train_test_split
from translate import Translator
# File to store the download count
COUNT_FILE = "download_count.txt"
# Initialize the count (if file doesn't exist, create it with 20213)
# NOTE(review): 20213 looks like a historical seed carried over from a previous
# deployment — confirm before changing.
if not os.path.exists(COUNT_FILE):
    with open(COUNT_FILE, "w") as f:
        f.write("20213")
# Function to get current download count
def get_download_count():
    """Return the persisted download count.

    Falls back to the seed value 20213 when the counter file is missing,
    unreadable, or contains non-integer content (previously this raised and
    took the UI down with it).
    """
    try:
        with open(COUNT_FILE, "r") as f:
            return int(f.read().strip())
    except (OSError, ValueError) as e:
        logger.warning("Could not read download count from %s: %s", COUNT_FILE, e)
        return 20213
# Function to increment count on download click
def update_download_count():
    """Increment the persisted download count by one and return the new value.

    NOTE(review): this read-modify-write is not atomic; concurrent callers may
    lose increments — acceptable for a vanity counter.
    """
    new_count = get_download_count() + 1
    with open(COUNT_FILE, "w") as f:
        f.write(str(new_count))
    return new_count
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Ensure the VADER lexicon is available. Download quietly and never let a
# network failure at import time crash the whole app — sentiment analysis
# will surface its own error later if the lexicon is truly missing.
try:
    nltk.download('vader_lexicon', quiet=True)
except Exception as e:
    logger.warning("Could not download VADER lexicon: %s", e)

# Import gTTS for speech synthesis (install via: pip install gTTS)
# from gtts import gTTS  # (Uncomment if you plan to use speech functionality outside UI)

# Optional: import ipywidgets for dropdowns in Jupyter/Colab
try:
    from ipywidgets import Dropdown, VBox, Button, Output
    from IPython.display import display
    WIDGETS_AVAILABLE = True
except ImportError:
    WIDGETS_AVAILABLE = False

# Import gradio for UI
import gradio as gr
class RobustAIDatasetGenerator:
    """Generates datasets for multiple AI tasks from public APIs, with synthetic fallbacks."""

    def __init__(self):
        """Create the output directory and load all API credentials from the environment."""
        # Create output directory for every generated artifact (CSV/JSON/XLSX/ZIP).
        self.output_dir = 'comprehensive_ai_datasets'
        os.makedirs(self.output_dir, exist_ok=True)
        # API keys for external data (set via environment variables; all optional —
        # each fetcher degrades gracefully when its key is missing).
        self.news_api_key = os.environ.get('NEWS_API_KEY')
        self.github_api_key = os.environ.get('GITHUB_API_KEY')
        self.twitter_api_key = os.environ.get('TWITTER_API_KEY')
        self.reddit_api_key = os.environ.get('REDDIT_API_KEY')  # stray trailing comma removed
        self.pixabay_api_key = os.environ.get('PIXABAY_API_KEY')
        self.unsplash_api_key = os.environ.get('UNSPLASH_API_KEY')
        self.pexels_api_key = os.environ.get('PEXELS_API_KEY')
        self.github_token = os.getenv("GITHUB_ACCESS_TOKEN")
        if not self.github_token:
            logger.warning("GitHub Access Token not found! Set it using the environment variable GITHUB_ACCESS_TOKEN.")
| # ---------- Conversation Dataset ---------- | |
| def generate_conversation_dataset(self, keywords: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| conversation_templates = [ | |
| { | |
| "context": "Discussing {topic}", | |
| "question": "What are your thoughts on {topic}?", | |
| "answer": "Here's an insightful perspective on {topic}...", | |
| "argument_strength": random.uniform(0.5, 1.0) | |
| }, | |
| { | |
| "context": "Exploring different viewpoints about {topic}", | |
| "question": "Can you explain the main arguments for and against {topic}?", | |
| "answer": "The key arguments about {topic} include various perspectives...", | |
| "argument_strength": random.uniform(0.6, 1.0) | |
| } | |
| ] | |
| for _ in range(num_samples): | |
| keyword = keywords[0] if keywords else "general" | |
| template = random.choice(conversation_templates) | |
| record = { | |
| "task": "conversation_generation", | |
| "topic": keyword, | |
| "context": template["context"].format(topic=keyword), | |
| "question": template["question"].format(topic=keyword), | |
| "answer": template["answer"].format(topic=keyword), | |
| "argument_strength": template["argument_strength"], | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| return dataset | |
| # ---------- News Dataset ---------- | |
| def fetch_news_data(self, keywords: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| try: | |
| url = 'https://newsapi.org/v2/everything' | |
| keyword = keywords[0] if keywords else "" | |
| params = { | |
| 'apiKey': self.news_api_key, | |
| 'q': keyword, | |
| 'language': 'en', | |
| 'sortBy': 'relevancy', | |
| 'pageSize': min(num_samples, 100) | |
| } | |
| response = requests.get(url, params=params, timeout=10) | |
| news_data = response.json() | |
| for article in news_data.get('articles', []): | |
| record = { | |
| "task": "text_generation", | |
| "topic": keyword, | |
| "title": article.get('title', ''), | |
| "description": article.get('description', ''), | |
| "content": article.get('content', ''), | |
| "source": article.get('source', {}).get('name', ''), | |
| "url": article.get('url', ''), | |
| "published_at": article.get('publishedAt', datetime.now().isoformat()) | |
| } | |
| dataset.append(record) | |
| except Exception as e: | |
| logger.error(f"News API fetch failed: {e}") | |
| return dataset | |
| # ---------- Video Dataset using Pexels Videos API ---------- | |
| def fetch_video_data(self, keywords: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| keyword = keywords[0] if keywords else "" | |
| if self.pexels_api_key: | |
| try: | |
| url = "https://api.pexels.com/videos/search" | |
| headers = {"Authorization": self.pexels_api_key} | |
| params = { | |
| "query": keyword, | |
| "per_page": num_samples | |
| } | |
| response = requests.get(url, headers=headers, params=params, timeout=10) | |
| data = response.json() | |
| for video in data.get("videos", []): | |
| record = { | |
| "task": "video_generation", | |
| "keyword": keyword, | |
| "prompt": f"Generate a video about {keyword}", | |
| "video_url": video.get("video_files", [{}])[0].get("link", ""), | |
| "metadata": {"duration": video.get("duration", 0)}, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| except Exception as e: | |
| logger.error(f"Pexels Videos API fetch failed: {e}") | |
| if not dataset: | |
| simulated_video_urls = [ | |
| "https://www.pexels.com/video/pexels-video-123456/", | |
| "https://www.pexels.com/video/pexels-video-234567/", | |
| "https://www.pexels.com/video/pexels-video-345678/", | |
| "https://www.pexels.com/video/pexels-video-456789/", | |
| "https://www.pexels.com/video/pexels-video-567890/" | |
| ] | |
| for i in range(num_samples): | |
| record = { | |
| "task": "video_generation", | |
| "keyword": keyword, | |
| "prompt": f"Generate a video about {keyword} (sample {i+1})", | |
| "video_url": simulated_video_urls[i % len(simulated_video_urls)], | |
| "metadata": {"duration": random.randint(30, 300)}, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| return dataset | |
| # ---------- Translation Dataset using Google Translate API ---------- | |
| def fetch_translation_data(self, phrases: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| supported_languages = [ | |
| "en", "es", "fr", "de", "hi", "zh", "ar", "ru", "ja", "it", "pt", "ko", "tr", "nl", "sv", "id" | |
| ] | |
| for phrase in phrases: | |
| for lang in supported_languages: | |
| try: | |
| translator = Translator(to_lang=lang) | |
| translation = translator.translate(phrase) | |
| if not translation or translation.strip() == phrase: | |
| raise ValueError("Translation failed or returned the same text.") | |
| except Exception as e: | |
| logger.error(f"Translation failed for '{phrase}' to '{lang}': {e}") | |
| continue | |
| record = { | |
| "task": "translation", | |
| "source_text": phrase, | |
| "target_language": lang, | |
| "translated_text": translation, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| return dataset[:num_samples] if num_samples < len(dataset) else dataset | |
| # ---------- Code Generation Dataset (Enhanced with GitHub Code Search) ---------- | |
| def fetch_code_from_github(self, keyword: str, languages: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| headers = {"Authorization": f"token {self.github_token}"} | |
| if not self.github_token: | |
| logger.error("No GitHub Access Token provided. Please generate a PAT and set it as an environment variable.") | |
| return [] | |
| for lang in languages: | |
| try: | |
| query = f"{keyword} language:{lang}" | |
| url = f"https://api.github.com/search/code?q={query}&sort=stars&order=desc&per_page={num_samples}" | |
| response = requests.get(url, headers=headers, timeout=10) | |
| if response.status_code == 403 and "X-RateLimit-Remaining" in response.headers: | |
| reset_time = int(response.headers.get("X-RateLimit-Reset", 0)) - int(datetime.now().timestamp()) | |
| logger.warning(f"Rate limit exceeded. Sleeping for {reset_time} seconds.") | |
| import time | |
| time.sleep(reset_time + 1) | |
| return self.fetch_code_from_github(keyword, languages, num_samples) | |
| data = response.json() | |
| if not data.get("items"): | |
| logger.warning(f"No code samples found for {keyword} in {lang}.") | |
| continue | |
| for item in data["items"]: | |
| file_url = item.get("html_url") | |
| if not file_url: | |
| logger.warning(f"Skipping item with missing file URL: {item}") | |
| continue | |
| raw_url = file_url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/") | |
| repo_name = item["repository"]["full_name"] | |
| code_snippet = self.fetch_raw_code(raw_url, keyword) | |
| if code_snippet.strip(): | |
| dataset.append({ | |
| "task": "code_generation", | |
| "problem_title": f"{keyword} in {lang}", | |
| "problem_description": f"Real implementation of {keyword} in {lang} fetched from GitHub.", | |
| "sample_solution": code_snippet, | |
| "language": lang, | |
| "repo_name": repo_name, | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| if len(dataset) >= num_samples: | |
| return dataset | |
| except Exception as e: | |
| logger.error(f"GitHub code fetch failed for language {lang}: {e}") | |
| import time | |
| time.sleep(2) | |
| return dataset | |
| def fetch_raw_code(self, download_url: str, keyword: str) -> str: | |
| try: | |
| headers = {"User-Agent": "Mozilla/5.0"} | |
| response = requests.get(download_url, headers=headers, timeout=10) | |
| response.raise_for_status() | |
| raw_code = response.text if response.text else "" | |
| if not raw_code.strip(): | |
| logger.warning(f"Fetched empty or invalid content from {download_url}") | |
| return "" | |
| if keyword and keyword.lower() not in raw_code.lower(): | |
| return "" | |
| return raw_code[:2000] | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"Failed to fetch raw code from {download_url}: {e}") | |
| return "" | |
| def build_code_generation_dataset(self, num_samples: int, keyword: str, prog_langs: List[str] = ["Python"]) -> List[Dict]: | |
| problems = self.fetch_code_from_github(keyword, prog_langs, num_samples) if keyword else [] | |
| dataset = [] | |
| for i in range(num_samples): | |
| problem = problems[i % len(problems)] if problems else {} | |
| dataset.append({ | |
| "task": "code_generation", | |
| "problem_title": problem.get("problem_title", ""), | |
| "problem_description": problem.get("problem_description", ""), | |
| "sample_solution": problem.get("sample_solution", ""), | |
| "language": problem.get("language", ""), | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| return dataset | |
    def fetch_sentiment_data(self, keywords: List[str], num_samples: int) -> List[Dict]:
        # Delegate to the module-level NewsAPI + VADER helper; only the first
        # keyword is used (empty string when no keywords were given).
        return fetch_sentiment_data_news(keywords[0] if keywords else "", num_samples)
| # ---------- Speech Generation using gTTS (Real Speech Synthesis) ---------- | |
| def generate_speech_file(self, text: str, language: str) -> Dict: | |
| try: | |
| from gtts import gTTS | |
| tts = gTTS(text=text, lang=language) | |
| filename = f"speech_{abs(hash(text)) % 10000}_{language}.mp3" | |
| file_path = os.path.join(self.output_dir, filename) | |
| tts.save(file_path) | |
| return { | |
| "audio_url": file_path, | |
| "speaker_info": { | |
| "library": "gTTS", | |
| "library_link": "https://pypi.org/project/gTTS/" | |
| } | |
| } | |
| except Exception as e: | |
| logger.error(f"gTTS failed for text '{text}': {e}") | |
| return {"audio_url": "", "speaker_info": {}} | |
| def generate_speech_dataset(self, phrases: List[str], languages: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| for i in range(num_samples): | |
| phrase = phrases[i % len(phrases)] | |
| lang = languages[i % len(languages)] | |
| tts_result = self.generate_speech_file(phrase, lang) | |
| record = { | |
| "task": "speech_generation", | |
| "text": phrase, | |
| "language": lang, | |
| "audio_url": tts_result.get("audio_url", ""), | |
| "speaker_info": tts_result.get("speaker_info", {}), | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| return dataset | |
| # ---------- Object Detection with Real Image Download & Analysis Simulation ---------- | |
| def download_image(self, image_url: str) -> BytesIO: | |
| try: | |
| response = requests.get(image_url, stream=True, timeout=10) | |
| if response.status_code == 200: | |
| return BytesIO(response.content) | |
| except Exception as e: | |
| logger.error(f"Failed to download image from {image_url}: {e}") | |
| return None | |
| def analyze_image(self, image_stream: BytesIO) -> Dict: | |
| width = random.randint(200, 800) | |
| height = random.randint(200, 800) | |
| num_boxes = random.randint(1, 3) | |
| boxes = [] | |
| for _ in range(num_boxes): | |
| x1 = random.randint(0, width//2) | |
| y1 = random.randint(0, height//2) | |
| x2 = random.randint(x1+10, width) | |
| y2 = random.randint(y1+10, height) | |
| boxes.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2}) | |
| return {"width": width, "height": height, "bounding_boxes": boxes} | |
| def build_object_detection_dataset(self, keywords: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| keyword = keywords[0] if keywords else "general" | |
| try: | |
| url = "https://api.unsplash.com/search/photos" | |
| params = { | |
| "client_id": self.unsplash_api_key, | |
| "query": keyword, | |
| "per_page": num_samples | |
| } | |
| response = requests.get(url, params=params, timeout=10) | |
| data = response.json() | |
| for i, result in enumerate(data.get("results", [])): | |
| image_url = result.get("urls", {}).get("regular", "") | |
| image_stream = self.download_image(image_url) | |
| if image_stream: | |
| analysis = self.analyze_image(image_stream) | |
| else: | |
| analysis = {"width": None, "height": None, "bounding_boxes": []} | |
| record = { | |
| "task": "object_detection", | |
| "image_url": image_url, | |
| "analysis": analysis, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| except Exception as e: | |
| logger.error(f"Unsplash API for object detection failed: {e}") | |
| while len(dataset) < num_samples: | |
| image_url = f"https://example.com/images/{keyword.replace(' ', '_')}_{len(dataset)+1}.jpg" | |
| analysis = {"width": 640, "height": 480, "bounding_boxes": [{"x1": 50, "y1": 50, "x2": 200, "y2": 200}]} | |
| dataset.append({ | |
| "task": "object_detection", | |
| "image_url": image_url, | |
| "analysis": analysis, | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| return dataset[:num_samples] | |
| # ---------- Image Generation Dataset ---------- | |
| def fetch_pixabay_images(self, keyword: str, num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| try: | |
| url = "https://pixabay.com/api/" | |
| params = { | |
| "key": self.pixabay_api_key, | |
| "q": keyword, | |
| "image_type": "photo", | |
| "per_page": num_samples | |
| } | |
| response = requests.get(url, params=params, timeout=10) | |
| data = response.json() | |
| for hit in data.get("hits", []): | |
| record = { | |
| "task": "image_generation", | |
| "source": "Pixabay", | |
| "keyword": keyword, | |
| "image_url": hit.get("webformatURL", ""), | |
| "tags": hit.get("tags", ""), | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| except Exception as e: | |
| logger.error(f"Pixabay API fetch failed: {e}") | |
| return dataset | |
| def fetch_unsplash_images(self, keyword: str, num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| try: | |
| url = "https://api.unsplash.com/search/photos" | |
| params = { | |
| "client_id": self.unsplash_api_key, | |
| "query": keyword, | |
| "per_page": num_samples | |
| } | |
| response = requests.get(url, params=params, timeout=10) | |
| data = response.json() | |
| for result in data.get("results", []): | |
| record = { | |
| "task": "image_generation", | |
| "source": "Unsplash", | |
| "keyword": keyword, | |
| "image_url": result.get("urls", {}).get("regular", ""), | |
| "description": result.get("description", "") or result.get("alt_description", ""), | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| dataset.append(record) | |
| except Exception as e: | |
| logger.error(f"Unsplash API fetch failed: {e}") | |
| return dataset | |
| def fetch_image_generation_data(self, keywords: List[str], num_samples: int) -> List[Dict]: | |
| dataset = [] | |
| keyword = keywords[0] if keywords else "" | |
| pixabay_data = self.fetch_pixabay_images(keyword, num_samples // 2) | |
| unsplash_data = self.fetch_unsplash_images(keyword, num_samples - len(pixabay_data)) | |
| dataset.extend(pixabay_data) | |
| dataset.extend(unsplash_data) | |
| return dataset | |
| # ---------- Comprehensive Dataset Generation ---------- | |
| def generate_dataset(self, task: str, keywords: List[str], extra_input: Dict, num_samples: int, split_label: str) -> pd.DataFrame: | |
| dataset = [] | |
| if task == "conversation": | |
| dataset = self.generate_conversation_dataset(keywords, num_samples) | |
| elif task == "news": | |
| dataset = self.fetch_news_data(keywords, num_samples) | |
| elif task == "video": | |
| dataset = self.fetch_video_data(keywords, num_samples) | |
| elif task == "translation": | |
| phrases = extra_input.get("phrases", []) | |
| dataset = self.fetch_translation_data(phrases, num_samples) | |
| elif task == "code": | |
| keyword = keywords[0] if keywords else None | |
| prog_langs = extra_input.get("prog_languages", ["Python", "Java", "C++", "Java Script"]) | |
| dataset = self.fetch_code_from_github(keyword, prog_langs, num_samples) | |
| elif task == "speech": | |
| phrases = extra_input.get("phrases", []) | |
| speech_langs = extra_input.get("speech_languages", ["en"]) | |
| dataset = self.generate_speech_dataset(phrases, speech_langs, num_samples) | |
| elif task == "object_detection": | |
| dataset = self.build_object_detection_dataset(keywords, num_samples) | |
| elif task == "text_generation": | |
| dataset = self.fetch_news_data(keywords, num_samples) | |
| elif task == "sentiment_analysis": | |
| dataset = self.fetch_sentiment_data(keywords, num_samples) | |
| elif task == "image_generation": | |
| dataset = self.fetch_image_generation_data(keywords, num_samples) | |
| else: | |
| for i in range(num_samples): | |
| dataset.append({ | |
| "task": "synthetic_generation", | |
| "topic": random.choice(keywords) if keywords else "general", | |
| "content": f"Synthetic content about {random.choice(keywords) if keywords else 'general'}", | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| for record in dataset: | |
| record["split"] = split_label | |
| return pd.DataFrame(dataset[:num_samples]) | |
| # ---------- Saving the Dataset ---------- | |
| def save_dataset(self, df: pd.DataFrame, task_type: str): | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| base_filename = os.path.join(self.output_dir, f"{task_type}_dataset_{timestamp}") | |
| df.to_csv(f"{base_filename}.csv", index=False) | |
| df.to_json(f"{base_filename}.json", orient='records', indent=2) | |
| df.to_excel(f"{base_filename}.xlsx", index=False) | |
| report = { | |
| "total_samples": len(df), | |
| "columns": list(df.columns), | |
| "timestamp": timestamp | |
| } | |
| with open(f"{base_filename}_report.json", 'w') as f: | |
| json.dump(report, f, indent=2) | |
| print(f"Dataset saved successfully:\n CSV: {base_filename}.csv\n JSON: {base_filename}.json\n Excel: {base_filename}.xlsx\n Report: {base_filename}_report.json") | |
def fetch_sentiment_data_news(keyword: str, num_samples: int) -> List[Dict]:
    """Fetch news articles for `keyword` via NewsAPI and score each with VADER.

    Returns an empty list when NEWS_API_KEY is unset, the request fails, or
    the response carries no articles.
    """
    dataset: List[Dict] = []
    news_api_key = os.environ.get('NEWS_API_KEY')
    # Guard: without a key NewsAPI always rejects the request — skip the round trip.
    if not news_api_key:
        logger.warning("NEWS_API_KEY is not set; skipping sentiment fetch.")
        return dataset
    url = "https://newsapi.org/v2/everything"
    params = {
        "apiKey": news_api_key,
        "q": keyword,
        "language": "en",
        "pageSize": num_samples
    }
    try:
        response = requests.get(url, params=params, timeout=10)
        data = response.json()
        if "articles" not in data:
            logger.error(f"Unexpected response: {data}")
            return dataset
        sia = SentimentIntensityAnalyzer()
        for article in data.get("articles", []):
            # Prefer the description; fall back to the (truncated) content field.
            text = article.get("description", "") or article.get("content", "")
            if not text:
                continue
            sentiment = sia.polarity_scores(text)
            dataset.append({
                "task": "sentiment_analysis",
                "source": "NewsAPI",
                "keyword": keyword,
                "title": article.get("title", ""),
                "description": article.get("description", ""),
                "content": article.get("content", ""),
                "sentiment": sentiment,
                "timestamp": datetime.now().isoformat()
            })
    except Exception as e:
        logger.error(f"NewsAPI fetch failed: {e}")
    return dataset
# ---------- Gradio UI Function ----------
# In-memory counter shown in the UI; mirrors the persisted seed in COUNT_FILE
# but is only incremented in memory by generate_dataset_ui.
download_count = 20213 # Initial count
def generate_dataset_ui(task, keywords, phrases, prog_languages, num_samples, dataset_split):
    """Generates dataset, saves files, and returns ZIP for download.

    Returns a (status message, zip path, preview DataFrame, download count)
    tuple matching the Gradio output components.
    """
    global download_count  # Update counter here
    generator = RobustAIDatasetGenerator()
    extra_input = {}
    task = task.lower()
    if task == "speech":
        return "Speech task is Upcoming. Please select another task.", None, None, download_count
    if task in ["translation"]:
        extra_input["phrases"] = [p.strip() for p in phrases.split(",")] if phrases else []
    if task == "code":
        extra_input["prog_languages"] = [p.strip() for p in prog_languages.split(",")] if prog_languages else []
    keywords_list = [k.strip() for k in keywords.split(",")] if keywords else []
    try:
        num_samples = int(num_samples)
    except (TypeError, ValueError):
        # BUG FIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; default to 5 records on bad input.
        num_samples = 5
    df = generator.generate_dataset(task, keywords_list, extra_input, num_samples, dataset_split)
    generator.save_dataset(df, task)
    # Locate the latest output files for this task.
    pattern = os.path.join(generator.output_dir, f"{task}_dataset_*.csv")
    csv_files = glob.glob(pattern)
    if not csv_files:
        return "No output files found.", None, None, download_count
    latest_csv = max(csv_files, key=os.path.getctime)
    base_filename = os.path.splitext(latest_csv)[0]
    # Define file paths derived from the newest CSV.
    csv_path = f"{base_filename}.csv"
    json_path = f"{base_filename}.json"
    excel_path = f"{base_filename}.xlsx"
    report_path = f"{base_filename}_report.json"
    zip_path = f"{base_filename}_all.zip"
    # Create a ZIP file bundling every generated output that exists.
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in [csv_path, json_path, excel_path, report_path]:
            if os.path.exists(file):
                zipf.write(file, os.path.basename(file))
    # Read CSV for preview (first 5 rows)
    try:
        preview_df = pd.read_csv(csv_path).head(5)
    except Exception as e:
        preview_df = pd.DataFrame({"Error": [f"Could not read CSV preview: {e}"]})
    # Update download counter when dataset is generated.
    download_count += 1
    return f"Dataset generated and saved. ZIP available for download.", zip_path, preview_df, download_count
# ---------- Gradio UI ----------
# Markdown shown at the top of the app.
instruction_text = """
# AI Dataset Generator
- Select a task and enter relevant details.
- Click "Generate Dataset" to create a dataset and **increment the download counter**.
- Click the ZIP file to download it directly.
"""
# Dropdown choices; the UI handler strips everything after the first space,
# so "speech (Upcoming)" dispatches as "speech".
task_options = ["conversation", "news", "video", "translation", "code", "speech (Upcoming)", "object_detection", "sentiment_analysis", "text_generation", "image_generation"]
# Build the Gradio app: inputs (task, keywords, phrases, languages, sample
# count, split), then status/preview/download outputs wired to one button.
with gr.Blocks() as demo:
    gr.Markdown(instruction_text)
    with gr.Row():
        task_input = gr.Dropdown(choices=task_options, label="Select Task")
        keywords_input = gr.Textbox(label="Keywords (comma-separated)")
    with gr.Row():
        phrases_input = gr.Textbox(label="Phrases (for Translation)", placeholder="Only required for translation")
        prog_languages_input = gr.Textbox(label="Programming Languages", placeholder="For Code Generation (Python, Java, etc.)")
    with gr.Row():
        num_samples_input = gr.Number(value=5, label="Number of Records", precision=0)
        dataset_split_input = gr.Dropdown(choices=["training", "testing"], label="Dataset Split")
    file_info_text = gr.Textbox(label="Status", interactive=False)
    preview_table = gr.Dataframe(label="CSV Preview (first 5 rows)")
    download_file = gr.File(label="Download Generated ZIP")
    download_count_text = gr.Textbox(value=f"Total Downloads: {download_count}", interactive=False, label="Dataset Generated")
    def process_ui(task_sel, keywords, phrases, prog_langs, num_samples, dataset_split):
        # Strip any suffix after the first space (e.g. "speech (Upcoming)" ->
        # "speech") before handing off to the generator.
        task = task_sel.split(" ")[0]
        return generate_dataset_ui(task, keywords, phrases, prog_langs, num_samples, dataset_split)
    generate_btn = gr.Button("Generate Dataset")
    generate_btn.click(fn=process_ui,
                       inputs=[task_input, keywords_input, phrases_input, prog_languages_input, num_samples_input, dataset_split_input],
                       outputs=[file_info_text, download_file, preview_table, download_count_text]) # Updates counter when generating dataset
if __name__ == '__main__':
    # share=True exposes a public Gradio link (used on Hugging Face Spaces/Colab).
    demo.launch(share=True)