# utils.py
import os
import re
import json
import requests
import tempfile
from bs4 import BeautifulSoup
from typing import List, Literal
from pydantic import BaseModel
from pydub import AudioSegment, effects
from transformers import pipeline
import yt_dlp
import tiktoken
from groq import Groq  # Ensure Groq client is imported
import numpy as np
import torch  # Added to check CUDA availability

class DialogueItem(BaseModel):
    speaker: Literal["Jane", "John"]
    text: str


class Dialogue(BaseModel):
    dialogue: List[DialogueItem]
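
# For reference, a minimal sketch of the structure these models validate once the
# LLM's JSON is parsed (the sample text is illustrative, not from a real response):
#   Dialogue(dialogue=[DialogueItem(speaker="Jane", text="Welcome to the show!"),
#                      DialogueItem(speaker="John", text="Great to be here.")])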

# Initialize Whisper ASR pipeline
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=0 if torch.cuda.is_available() else -1,
)

def truncate_text(text, max_tokens=2048):
    print("[LOG] Truncating text if needed.")
    tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        print("[LOG] Text too long, truncating.")
        return tokenizer.decode(tokens[:max_tokens])
    return text

def extract_text_from_url(url):
    print("[LOG] Extracting text from URL:", url)
    try:
        response = requests.get(url)
        if response.status_code != 200:
            print(f"[ERROR] Failed to fetch URL: {url} with status code {response.status_code}")
            return ""
        soup = BeautifulSoup(response.text, 'html.parser')
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text(separator=' ')
        print("[LOG] Text extraction from URL successful.")
        return text
    except Exception as e:
        print(f"[ERROR] Exception during text extraction from URL: {e}")
        return ""

def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """
    Shifts the pitch of an AudioSegment by a given number of semitones.
    Positive semitones shift the pitch up, negative shift it down.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
    shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
    return shifted_audio.set_frame_rate(audio.frame_rate)

def is_sufficient(text: str, min_word_count: int = 500) -> bool:
    """
    Determines if the fetched information meets the sufficiency criteria.
    :param text: Aggregated text from primary sources.
    :param min_word_count: Minimum number of words required.
    :return: True if sufficient, False otherwise.
    """
    word_count = len(text.split())
    print(f"[DEBUG] Aggregated word count: {word_count}")
    return word_count >= min_word_count

def query_llm_for_additional_info(topic: str, existing_text: str) -> str:
    """
    Queries the Groq API to retrieve additional relevant information from the LLM's knowledge base.
    :param topic: The research topic.
    :param existing_text: The text already gathered from primary sources.
    :return: Additional relevant information as a string.
    """
    print("[LOG] Querying LLM for additional information.")
    # Define the system prompt for the LLM
    system_prompt = (
        "You are an AI assistant with extensive knowledge up to 2023-10. "
        "Provide additional relevant information on the following topic based on your knowledge base.\n\n"
        f"Topic: {topic}\n\n"
        f"Existing Information: {existing_text}\n\n"
        "Please add more insightful details, facts, and perspectives to enhance the understanding of the topic."
    )
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=1024,
            temperature=0.7
        )
    except Exception as e:
        print("[ERROR] Groq API error during fallback:", e)
        return ""
    additional_info = response.choices[0].message.content.strip()
    print("[DEBUG] Additional information from LLM:")
    print(additional_info)
    return additional_info

def research_topic(topic: str) -> str:
    # Sources:
    sources = {
        "BBC": "https://feeds.bbci.co.uk/news/rss.xml",
        "CNN": "http://rss.cnn.com/rss/edition.rss",
        "Associated Press": "https://apnews.com/apf-topnews",
        "NDTV": "https://www.ndtv.com/rss/top-stories",
        "Times of India": "https://timesofindia.indiatimes.com/rssfeeds/296589292.cms",
        "The Hindu": "https://www.thehindu.com/news/national/kerala/rssfeed.xml",
        "Economic Times": "https://economictimes.indiatimes.com/rssfeeds/1977021501.cms",
        "Google News - Custom": f"https://news.google.com/rss/search?q={requests.utils.quote(topic)}&hl=en-IN&gl=IN&ceid=IN:en",
    }
    summary_parts = []
    # Wikipedia summary
    wiki_summary = fetch_wikipedia_summary(topic)
    if wiki_summary:
        summary_parts.append(f"From Wikipedia: {wiki_summary}")
    # For each news RSS
    for name, url in sources.items():
        try:
            items = fetch_rss_feed(url)
            if not items:
                continue
            # Use simple keyword matching
            title, desc, link = find_relevant_article(items, topic, min_match=2)
            if link:
                article_text = fetch_article_text(link)
                if article_text:
                    summary_parts.append(f"From {name}: {article_text}")
                else:
                    # If no main text extracted, use title/desc
                    summary_parts.append(f"From {name}: {title} - {desc}")
        except Exception as e:
            print(f"[ERROR] Error fetching from {name} RSS feed:", e)
            continue
    aggregated_info = " ".join(summary_parts)
    print("[DEBUG] Aggregated information from primary sources.")
    print(aggregated_info)
    if not is_sufficient(aggregated_info):
        print("[LOG] Insufficient information from primary sources. Initiating fallback to LLM.")
        additional_info = query_llm_for_additional_info(topic, aggregated_info)
        if additional_info:
            aggregated_info += " " + additional_info
        else:
            print("[ERROR] Failed to retrieve additional information from LLM.")
    if not aggregated_info:
        # No info found at all
        print("[LOG] No information found for the topic.")
        return f"Sorry, I couldn't find recent information on '{topic}'."
    return aggregated_info

def fetch_wikipedia_summary(topic: str) -> str:
    print("[LOG] Fetching Wikipedia summary for:", topic)
    try:
        # 1. Search for the topic
        search_url = f"https://en.wikipedia.org/w/api.php?action=opensearch&search={requests.utils.quote(topic)}&limit=1&namespace=0&format=json"
        resp = requests.get(search_url)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch Wikipedia search results for topic: {topic}")
            return ""
        data = resp.json()
        if len(data) > 1 and data[1]:
            title = data[1][0]
            # 2. Fetch summary
            summary_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{requests.utils.quote(title)}"
            s_resp = requests.get(summary_url)
            if s_resp.status_code == 200:
                s_data = s_resp.json()
                if "extract" in s_data:
                    print("[LOG] Wikipedia summary fetched successfully.")
                    return s_data["extract"]
        print("[LOG] No Wikipedia summary found for topic:", topic)
        return ""
    except Exception as e:
        print(f"[ERROR] Exception during Wikipedia summary fetch: {e}")
        return ""

def fetch_rss_feed(feed_url: str) -> list:
    print("[LOG] Fetching RSS feed:", feed_url)
    try:
        resp = requests.get(feed_url)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch RSS feed: {feed_url} with status code {resp.status_code}")
            return []
        # Use html.parser instead of xml to avoid needing lxml or other parsers.
        soup = BeautifulSoup(resp.content, "html.parser")
        items = soup.find_all("item")
        print(f"[LOG] Number of items fetched from {feed_url}: {len(items)}")
        return items
    except Exception as e:
        print(f"[ERROR] Exception occurred while fetching RSS feed {feed_url}: {e}")
        return []

def find_relevant_article(items, topic: str, min_match=2) -> tuple:
    """
    Searches for relevant articles based on topic keywords.
    :param items: List of RSS feed items
    :param topic: Topic string
    :param min_match: Minimum number of keyword matches required
    :return: (title, description, link) or (None, None, None)
    """
    print("[LOG] Finding relevant articles...")
    keywords = re.findall(r'\w+', topic.lower())
    print(f"[LOG] Topic keywords: {keywords}")
    for item in items:
        title = item.find("title").get_text().strip() if item.find("title") else ""
        description = item.find("description").get_text().strip() if item.find("description") else ""
        text = f"{title.lower()} {description.lower()}"
        matches = sum(1 for kw in keywords if kw in text)
        print(f"[DEBUG] Checking article: '{title}' | Matches: {matches}/{len(keywords)}")
        if matches >= min_match:
            link = item.find("link").get_text().strip() if item.find("link") else ""
            print(f"[LOG] Relevant article found: {title}")
            return title, description, link
    print("[LOG] No relevant articles found based on the current matching criteria.")
    return None, None, None

def fetch_article_text(link: str) -> str:
    print("[LOG] Fetching article text from:", link)
    if not link:
        print("[LOG] No link provided for fetching article text.")
        return ""
    try:
        resp = requests.get(link)
        if resp.status_code != 200:
            print(f"[ERROR] Failed to fetch article from link: {link} with status code {resp.status_code}")
            return ""
        soup = BeautifulSoup(resp.text, 'html.parser')
        # This is site-specific. We'll try a generic approach:
        # Just take all paragraphs:
        paragraphs = soup.find_all("p")
        text = " ".join(p.get_text() for p in paragraphs[:5])  # first 5 paragraphs for more context
        print("[LOG] Article text fetched successfully.")
        return text.strip()
    except Exception as e:
        print(f"[ERROR] Error fetching article text: {e}")
        return ""

def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str):
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)
    groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    # Map target_length to word ranges
    length_mapping = {
        "1-3 Mins": (200, 450),
        "3-5 Mins": (450, 750),
        "5-10 Mins": (750, 1500),
        "10-20 Mins": (1500, 3000)
    }
    min_words, max_words = length_mapping.get(target_length, (200, 450))
    # Adjust tone description for clarity in prompt
    tone_description = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively"
    }
    chosen_tone = tone_description.get(tone, "casual")
    # Construct the prompt with clear instructions for JSON output
    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} ({min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        "Please provide the output in the following JSON format without any additional text:\n\n"
        "{\n"
        '  "dialogue": [\n'
        '    {\n'
        '      "speaker": "Jane",\n'
        '      "text": "..."\n'
        '    },\n'
        '    {\n'
        '      "speaker": "John",\n'
        '      "text": "..."\n'
        '    }\n'
        "  ]\n"
        "}"
    )
    print("[LOG] Sending prompt to Groq:")
    print(prompt)  # Log the prompt being sent
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "system", "content": prompt}],
            model="llama-3.3-70b-versatile",
            max_tokens=2048,
            temperature=0.7
        )
    except Exception as e:
        print("[ERROR] Groq API error:", e)
        raise ValueError(f"Error communicating with Groq API: {str(e)}")
    # Log the raw response content for debugging
    raw_content = response.choices[0].message.content.strip()
    print("[DEBUG] Raw API response content:")
    print(raw_content)
    # Attempt to extract JSON from the response
    content = raw_content.replace('```json', '').replace('```', '').strip()
    start_index = content.find('{')
    end_index = content.rfind('}')
    if start_index == -1 or end_index == -1:
        print("[ERROR] Failed to parse dialogue. No JSON found.")
        print("[ERROR] Entire response content:")
        print(content)
        raise ValueError("Failed to parse dialogue: Could not find JSON object in response.")
    json_str = content[start_index:end_index + 1].strip()
    print("[DEBUG] Extracted JSON string:")
    print(json_str)
    try:
        data = json.loads(json_str)
        print("[LOG] Script generated successfully.")
        return Dialogue(**data)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding failed:", e)
        print("[ERROR] Response content causing failure:")
        print(content)
        raise ValueError(f"Failed to parse dialogue: {str(e)}")

def generate_audio_mp3(text: str, speaker: str) -> str:
    try:
        print(f"[LOG] Generating audio for speaker: {speaker}")
        # Define Deepgram API endpoint
        deepgram_api_url = "https://api.deepgram.com/v1/speak"
        # Prepare query parameters
        params = {
            "model": "aura-asteria-en",  # Default model; adjust if needed
            # You can add more parameters here as needed, e.g., bit_rate, sample_rate, etc.
        }
        # Override model if needed based on speaker
        if speaker == "Jane":
            params["model"] = "aura-asteria-en"  # Female voice
        elif speaker == "John":
            params["model"] = "aura-orpheus-en"  # Male voice
        else:
            raise ValueError(f"Unknown speaker: {speaker}")
        # Prepare headers
        headers = {
            "Accept": "audio/mpeg",  # Request MP3 files
            "Content-Type": "application/json",
            "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
        }
        # Prepare body
        body = {
            "text": text
        }
        print("[LOG] Sending TTS request to Deepgram...")
        # Make the POST request to Deepgram's TTS API
        response = requests.post(deepgram_api_url, params=params, headers=headers, json=body, stream=True)
        if response.status_code != 200:
            print(f"[ERROR] Deepgram TTS API returned status code {response.status_code}: {response.text}")
            raise ValueError(f"Deepgram TTS API error: {response.status_code} - {response.text}")
        # Verify Content-Type
        content_type = response.headers.get('Content-Type', '')
        if 'audio/mpeg' not in content_type:
            print("[ERROR] Unexpected Content-Type received from Deepgram:", content_type)
            print("[ERROR] Response content:", response.text)
            raise ValueError("Unexpected Content-Type received from Deepgram.")
        # Save the streamed audio to a temporary MP3 file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    mp3_file.write(chunk)
            mp3_temp_path = mp3_file.name
        print(f"[LOG] Audio received from Deepgram and saved at: {mp3_temp_path}")
        # Normalize audio volume
        audio_seg = AudioSegment.from_file(mp3_temp_path, format="mp3")
        audio_seg = effects.normalize(audio_seg)
        # Removed pitch shifting for male voice
        # Previously:
        # if speaker == "John":
        #     semitones = -5  # Shift down by 5 semitones for a deeper voice
        #     audio_seg = pitch_shift(audio_seg, semitones=semitones)
        #     print(f"[LOG] Applied pitch shift to John's voice by {semitones} semitones.")
        # Export the final audio as MP3
        final_mp3_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
        audio_seg.export(final_mp3_path, format="mp3")
        print("[LOG] Audio post-processed and saved at:", final_mp3_path)
        # Clean up the initial MP3 file
        if os.path.exists(mp3_temp_path):
            os.remove(mp3_temp_path)
            print(f"[LOG] Removed temporary MP3 file: {mp3_temp_path}")
        return final_mp3_path
    except Exception as e:
        print("[ERROR] Error generating audio:", e)
        raise ValueError(f"Error generating audio: {str(e)}")

def transcribe_youtube_video(video_url: str) -> str:
    print("[LOG] Transcribing YouTube video:", video_url)
    fd, audio_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': audio_file,
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192'
        }],
        'quiet': True,
        'no_warnings': True,
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([video_url])
    except yt_dlp.utils.DownloadError as e:
        print("[ERROR] yt-dlp download error:", e)
        raise ValueError(f"Error downloading YouTube video: {str(e)}")
    print("[LOG] Audio downloaded at:", audio_file)
    try:
        # Run ASR on the downloaded audio
        result = asr_pipeline(audio_file)
        transcript = result["text"]
        print("[LOG] Transcription completed.")
        return transcript.strip()
    except Exception as e:
        print("[ERROR] ASR transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video: {str(e)}")
    finally:
        # Clean up the downloaded audio file
        if os.path.exists(audio_file):
            os.remove(audio_file)
            print(f"[LOG] Removed temporary audio file: {audio_file}")