import os
import re
import json
import requests
import tempfile
import random
import math
import numpy as np
import torch
import time
from bs4 import BeautifulSoup
from typing import List, Literal, Optional
from pydantic import BaseModel
from pydub import AudioSegment, effects
from transformers import pipeline
import tiktoken
from groq import Groq
import streamlit as st  # If you use Streamlit for session state
from report_structure import generate_report  # Your PDF generator
from tavily import TavilyClient  # For search

###############################################################################
# DATA MODELS
###############################################################################

class DialogueItem(BaseModel):
    speaker: Literal["Jane", "John"]
    display_speaker: str = "Jane"
    text: str


class Dialogue(BaseModel):
    dialogue: List[DialogueItem]
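
# Example (sketch): constructing a one-line script with these models.
#
#     item = DialogueItem(speaker="Jane", display_speaker="Isha", text="Welcome!")
#     script = Dialogue(dialogue=[item])
#     print(script.model_dump())  # assumes pydantic v2; use .dict() on v1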

###############################################################################
# HYBRID RATE-LIMIT HANDLER
###############################################################################

def call_llm_with_retry(groq_client, **payload):
    """
    Wraps groq_client.chat.completions.create(**payload) in a retry loop
    to catch rate-limit errors or service-unavailable (503) errors.
    If we see "try again in XXs" or detect a 503 error, we parse the wait
    time, sleep, then retry. We also do a short sleep (0.3s) after each
    successful call to avoid bursting usage.
    """
    max_retries = 3
    for attempt in range(max_retries):
        try:
            print(f"[DEBUG] call_llm_with_retry attempt {attempt+1}")
            response = groq_client.chat.completions.create(**payload)
            # Short sleep to avoid bursting usage
            time.sleep(0.3)
            print("[DEBUG] LLM call succeeded, returning response.")
            return response
        except Exception as e:
            err_str = str(e).lower()
            print(f"[WARN] call_llm_with_retry attempt {attempt+1} failed: {e}")
            if ("rate_limit_exceeded" in err_str or "try again in" in err_str or "503" in err_str):
                # Default wait (also used for 503s with no parseable wait time).
                wait_time = 60.0
                match = re.search(r'try again in (\d+(?:\.\d+)?)s', str(e), re.IGNORECASE)
                if match:
                    wait_time = float(match.group(1)) + 1.0
                print(f"[WARN] Detected error (rate limit or 503). Sleeping for {wait_time:.1f}s, then retrying.")
                time.sleep(wait_time)
            else:
                raise
    raise RuntimeError("Exceeded max_retries due to repeated rate limit or other errors.")
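
# Example usage (sketch; assumes GROQ_API_KEY is set and the model name below
# is available on your Groq account):
#
#     groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
#     resp = call_llm_with_retry(
#         groq_client,
#         model="llama-3.1-8b-instant",
#         messages=[{"role": "user", "content": "Say hello in one sentence."}],
#         max_tokens=64,
#     )
#     print(resp.choices[0].message.content)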

###############################################################################
# TRUNCATION
###############################################################################

def truncate_text_tokens(text: str, max_tokens: int) -> str:
    """
    Truncates 'text' to 'max_tokens' tokens. Used for controlling the maximum
    total text size after scraping.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        truncated = tokenizer.decode(tokens[:max_tokens])
        print(f"[DEBUG] Truncating from {len(tokens)} tokens to {max_tokens} tokens.")
        return truncated
    return text


def truncate_text_for_llm(text: str, max_tokens: int = 1024) -> str:
    """
    Default truncation for partial merges or final LLM calls.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        truncated = tokenizer.decode(tokens[:max_tokens])
        print(f"[DEBUG] Truncating text from {len(tokens)} to {max_tokens} tokens for LLM.")
        return truncated
    return text
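
# Example (sketch): both helpers do deterministic token-level truncation with
# the cl100k_base encoding.
#
#     long_text = "word " * 5000
#     short = truncate_text_for_llm(long_text, max_tokens=100)
#     assert len(tiktoken.get_encoding("cl100k_base").encode(short)) <= 100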

###############################################################################
# PITCH SHIFT (Optional)
###############################################################################

def pitch_shift(audio: AudioSegment, semitones: int) -> AudioSegment:
    """
    Shifts pitch by resampling: raise the frame rate by 2^(semitones/12),
    then restore the original frame rate. Note this also changes playback
    speed, which is acceptable for small offsets.
    """
    print(f"[LOG] Shifting pitch by {semitones} semitones.")
    new_sample_rate = int(audio.frame_rate * (2.0 ** (semitones / 12.0)))
    shifted_audio = audio._spawn(audio.raw_data, overrides={'frame_rate': new_sample_rate})
    return shifted_audio.set_frame_rate(audio.frame_rate)
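
# Example usage (sketch; assumes a local 'voice.mp3' exists):
#
#     seg = AudioSegment.from_file("voice.mp3", format="mp3")
#     deeper = pitch_shift(seg, -2)  # two semitones down
#     deeper.export("voice_deeper.mp3", format="mp3")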

###############################################################################
# PODCAST SCRIPT GENERATION (Single Call)
###############################################################################

def generate_script(
    system_prompt: str,
    input_text: str,
    tone: str,
    target_length: str,
    host_name: str = "Jane",
    guest_name: str = "John",
    sponsor_style: str = "Separate Break",
    sponsor_provided=None
):
    """
    Generates the entire script in a single LLM call.
    Uses DeepSeek R1 (via OpenRouter). Just ensure you parse the JSON.
    """
    print("[LOG] Generating script with tone:", tone, "and length:", target_length)
    language_selection = st.session_state.get("language_selection", "English (American)")
    if (host_name == "Jane" or not host_name) and language_selection in ["English (Indian)", "Hinglish", "Hindi"]:
        host_name = "Isha"
    if (guest_name == "John" or not guest_name) and language_selection in ["English (Indian)", "Hinglish", "Hindi"]:
        guest_name = "Aarav"

    words_per_minute = 150
    numeric_minutes = 3
    match = re.search(r"(\d+)", target_length)
    if match:
        numeric_minutes = int(match.group(1))
    min_words = max(50, numeric_minutes * 100)
    max_words = numeric_minutes * words_per_minute

    tone_map = {
        "Humorous": "funny and exciting, makes people chuckle",
        "Formal": "business-like, well-structured, professional",
        "Casual": "like a conversation between close friends, relaxed and informal",
        "Youthful": "like how teenagers might chat, energetic and lively"
    }
    chosen_tone = tone_map.get(tone, "casual")

    if sponsor_provided:
        if sponsor_style == "Separate Break":
            sponsor_instructions = (
                "If sponsor content is provided, include it in a separate ad break (~30 seconds). "
                "Use 'Now a word from our sponsor...' and end with 'Back to the show', etc."
            )
        else:
            sponsor_instructions = (
                "If sponsor content is provided, blend it naturally (~30 seconds) into conversation. "
                "Avoid abrupt transitions."
            )
    else:
        sponsor_instructions = ""

    prompt = (
        f"{system_prompt}\n"
        f"TONE: {chosen_tone}\n"
        f"TARGET LENGTH: {target_length} (~{min_words}-{max_words} words)\n"
        f"INPUT TEXT: {input_text}\n\n"
        f"# Sponsor Style Instruction:\n{sponsor_instructions}\n\n"
        "Please provide the output in the following JSON format without any extra text:\n"
        "{\n"
        '  "dialogue": [\n'
        '    { "speaker": "Jane", "text": "..." },\n'
        '    { "speaker": "John", "text": "..." }\n'
        "  ]\n"
        "}"
    )
    if language_selection == "Hinglish":
        prompt += "\n\nPlease generate the script in Romanized Hindi.\n"
    elif language_selection == "Hindi":
        prompt += "\n\nPlease generate the script exclusively in Hindi.\n"

    print("[LOG] Sending script generation prompt to LLM.")
    try:
        headers = {
            "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY')}",
            "Content-Type": "application/json"
        }
        data = {
            "model": "deepseek/deepseek-r1",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2048,
            "temperature": 0.7
        }
        resp = requests.post("https://openrouter.ai/api/v1/chat/completions",
                             headers=headers, data=json.dumps(data))
        resp.raise_for_status()
        raw_content = resp.json()["choices"][0]["message"]["content"].strip()
    except Exception as e:
        print("[ERROR] LLM error generating script:", e)
        raise ValueError(f"Error generating script: {str(e)}")

    # Extract the JSON object from the raw LLM output.
    start_idx = raw_content.find("{")
    end_idx = raw_content.rfind("}")
    if start_idx == -1 or end_idx == -1:
        raise ValueError("No JSON found in LLM response for script generation.")
    json_str = raw_content[start_idx:end_idx+1]

    try:
        data_js = json.loads(json_str)
        dialogue_list = data_js.get("dialogue", [])
        # Map returned speaker names onto the canonical "Jane"/"John" voices,
        # preserving the original name for display.
        for d in dialogue_list:
            raw_speaker = d.get("speaker", "Jane")
            if raw_speaker.lower() == host_name.lower():
                d["speaker"] = "Jane"
                d["display_speaker"] = host_name
            elif raw_speaker.lower() == guest_name.lower():
                d["speaker"] = "John"
                d["display_speaker"] = guest_name
            else:
                d["speaker"] = "Jane"
                d["display_speaker"] = raw_speaker
        new_dialogue_items = []
        for d in dialogue_list:
            if "display_speaker" not in d:
                d["display_speaker"] = d["speaker"]
            new_dialogue_items.append(DialogueItem(**d))
        return Dialogue(dialogue=new_dialogue_items)
    except json.JSONDecodeError as e:
        print("[ERROR] JSON decoding failed for script generation:", e)
        raise ValueError(f"Script parse error: {str(e)}")
    except Exception as e:
        print("[ERROR] Unknown error parsing script JSON:", e)
        raise ValueError(f"Script parse error: {str(e)}")
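
# Example usage (sketch; assumes this runs inside a Streamlit session with
# DEEPSEEK_API_KEY set, since the function reads st.session_state):
#
#     script = generate_script(
#         system_prompt="You are a podcast script writer.",
#         input_text="A short article about solar energy...",
#         tone="Casual",
#         target_length="3 minutes",
#     )
#     for item in script.dialogue:
#         print(f"{item.display_speaker}: {item.text}")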

###############################################################################
# YOUTUBE TRANSCRIPTION (RAPIDAPI)
###############################################################################

def transcribe_youtube_video(video_url: str) -> str:
    print("[LOG] Transcribing YouTube video:", video_url)
    match = re.search(r"(?:v=|/)([0-9A-Za-z_-]{11})", video_url)
    if not match:
        raise ValueError(f"Invalid YouTube URL: {video_url}, cannot extract video ID.")
    video_id = match.group(1)
    print("[LOG] Extracted video ID:", video_id)
    base_url = "https://youtube-transcriptor.p.rapidapi.com/transcript"
    params = {"video_id": video_id, "lang": "en"}
    headers = {
        "x-rapidapi-host": "youtube-transcriptor.p.rapidapi.com",
        "x-rapidapi-key": os.environ.get("RAPIDAPI_KEY")
    }
    try:
        resp = requests.get(base_url, headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        if not isinstance(data, list) or not data:
            raise ValueError(f"Unexpected transcript format or empty transcript: {data}")
        transcript_as_text = data[0].get("transcriptionAsText", "").strip()
        if not transcript_as_text:
            raise ValueError("transcriptionAsText missing or empty in RapidAPI response.")
        print("[LOG] Transcript retrieval successful. Sample:", transcript_as_text[:200], "...")
        return transcript_as_text
    except Exception as e:
        print("[ERROR] YouTube transcription error:", e)
        raise ValueError(f"Error transcribing YouTube video: {str(e)}")
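
# Example usage (sketch; assumes RAPIDAPI_KEY is set and the video has an
# English transcript available):
#
#     text = transcribe_youtube_video("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
#     print(text[:300])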

###############################################################################
# AUDIO GENERATION (TTS) AND BG MUSIC MIX
###############################################################################

def _preprocess_text_for_tts(text: str, speaker: str) -> str:
    text = re.sub(r"\bNo\.\b", "Number", text, flags=re.IGNORECASE)
    # Note: a mid-pattern (?i) flag raises re.error on Python 3.11+,
    # so the case-insensitivity is passed via flags= instead.
    text = re.sub(r"\bSaaS\b", "sass", text, flags=re.IGNORECASE)

    # Spell out all-caps abbreviations letter by letter, except those
    # normally pronounced as words.
    abbreviations_as_words = {"NASA", "NATO", "UNESCO"}

    def insert_periods_for_abbrev(m):
        abbr = m.group(0)
        if abbr in abbreviations_as_words:
            return abbr
        return ".".join(list(abbr)) + "."

    text = re.sub(r"\b([A-Z]{2,})\b", insert_periods_for_abbrev, text)
    text = re.sub(r"\.\.", ".", text)

    def remove_periods_for_tts(m):
        return m.group().replace(".", " ").strip()

    text = re.sub(r"[A-Z]\.[A-Z](?:\.[A-Z])*\.", remove_periods_for_tts, text)
    text = re.sub(r"-", " ", text)
    text = re.sub(r"\b(ha(ha)?|heh|lol)\b", "(* laughs *)", text, flags=re.IGNORECASE)
    text = re.sub(r"\bsigh\b", "(* sighs *)", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(groan|moan)\b", "(* groans *)", text, flags=re.IGNORECASE)

    if speaker != "Jane":
        # Occasionally insert a thinking pause after key words for the guest voice.
        def insert_thinking_pause(m):
            wd = m.group(1)
            if random.random() < 0.3:
                filler = random.choice(["hmm,", "well,", "let me see,"])
                return f"{wd}..., {filler}"
            else:
                return f"{wd}...,"

        keywords_pattern = r"\b(important|significant|crucial|point|topic)\b"
        text = re.sub(keywords_pattern, insert_thinking_pause, text, flags=re.IGNORECASE)
        conj_pattern = r"\b(and|but|so|because|however)\b"
        text = re.sub(conj_pattern, lambda m: f"{m.group()}...", text, flags=re.IGNORECASE)

    text = re.sub(r"\b(uh|um|ah)\b", "", text, flags=re.IGNORECASE)

    def capitalize_after_sentence(m):
        return m.group().upper()

    text = re.sub(r'(^\s*\w)|([.!?]\s*\w)', capitalize_after_sentence, text)
    return text.strip()
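
# Example (sketch) of what the preprocessing does:
#
#     _preprocess_text_for_tts("The CEO of NASA laughed, haha.", speaker="John")
#     # -> roughly: "The C E O of NASA laughed, (* laughs *)."
#     # "CEO" gets spelled out, "NASA" is kept as a word, and laughter cues
#     # become (* laughs *) markers for the TTS engine.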

def generate_audio_mp3(text: str, speaker: str) -> str:
    """
    Uses Deepgram (English) or Murf (Indian/Hinglish/Hindi) for TTS.
    """
    print(f"[LOG] Generating TTS for speaker={speaker}")
    language_selection = st.session_state.get("language_selection", "English (American)")
    try:
        if language_selection == "English (American)":
            print("[LOG] Using Deepgram for American English TTS.")
            processed_text = text if speaker in ["Jane", "John"] else _preprocess_text_for_tts(text, speaker)
            deepgram_api_url = "https://api.deepgram.com/v1/speak"
            params = {"model": "aura-asteria-en"} if speaker != "John" else {"model": "aura-zeus-en"}
            headers = {
                "Accept": "audio/mpeg",
                "Content-Type": "application/json",
                "Authorization": f"Token {os.environ.get('DEEPGRAM_API_KEY')}"
            }
            body = {"text": processed_text}
            r = requests.post(deepgram_api_url, params=params, headers=headers, json=body, stream=True)
            r.raise_for_status()
            content_type = r.headers.get("Content-Type", "")
            if "audio/mpeg" not in content_type:
                raise ValueError("Unexpected content-type from Deepgram TTS.")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as mp3_file:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        mp3_file.write(chunk)
                mp3_path = mp3_file.name
            audio_seg = AudioSegment.from_file(mp3_path, format="mp3")
            audio_seg = effects.normalize(audio_seg)
            final_mp3_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
            audio_seg.export(final_mp3_path, format="mp3")
            if os.path.exists(mp3_path):
                os.remove(mp3_path)
            return final_mp3_path
        else:
            print("[LOG] Using Murf API for TTS. Language=", language_selection)
            from indic_transliteration.sanscript import transliterate, DEVANAGARI, IAST
            if language_selection == "Hinglish":
                text = transliterate(text, DEVANAGARI, IAST)
            api_key = os.environ.get("MURF_API_KEY")
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "api-key": api_key
            }
            multi_native_locale = "hi-IN" if language_selection in ["Hinglish", "Hindi"] else "en-IN"
            if language_selection == "English (Indian)":
                voice_id = "en-IN-aarav" if speaker == "John" else "en-IN-isha"
            elif language_selection in ["Hindi", "Hinglish"]:
                voice_id = "hi-IN-kabir" if speaker == "John" else "hi-IN-shweta"
            else:
                voice_id = "en-IN-aarav" if speaker == "John" else "en-IN-isha"
            payload = {
                "audioDuration": 0,
                "channelType": "MONO",
                "encodeAsBase64": False,
                "format": "WAV",
                "modelVersion": "GEN2",
                "multiNativeLocale": multi_native_locale,
                "pitch": 0,
                "pronunciationDictionary": {},
                "rate": 0,
                "sampleRate": 48000,
                "style": "Conversational",
                "text": text,
                "variation": 1,
                "voiceId": voice_id
            }
            r = requests.post("https://api.murf.ai/v1/speech/generate", headers=headers, json=payload)
            r.raise_for_status()
            j = r.json()
            audio_url = j.get("audioFile")
            if not audio_url:
                raise ValueError("No audioFile URL from Murf API.")
            audio_resp = requests.get(audio_url)
            audio_resp.raise_for_status()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_file:
                wav_file.write(audio_resp.content)
                wav_path = wav_file.name
            audio_seg = AudioSegment.from_file(wav_path, format="wav")
            audio_seg = effects.normalize(audio_seg)
            final_mp3_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
            audio_seg.export(final_mp3_path, format="mp3")
            os.remove(wav_path)
            return final_mp3_path
    except Exception as e:
        print("[ERROR] TTS generation error:", e)
        raise ValueError(f"Error generating TTS audio: {str(e)}")

def mix_with_bg_music(spoken: AudioSegment, custom_music_path=None) -> AudioSegment:
    """
    Overlays 'spoken' with background music, offset by ~2s, volume lowered.
    """
    if custom_music_path:
        music_path = custom_music_path
    else:
        music_path = "bg_music.mp3"
    try:
        bg_music = AudioSegment.from_file(music_path, format="mp3")
    except Exception as e:
        print("[ERROR] Failed to load background music:", e)
        return spoken
    bg_music = bg_music - 18.0  # Lower music volume by 18 dB.
    total_length_ms = len(spoken) + 2000
    # Loop the music until it covers the spoken track plus the 2s lead-in.
    looped_music = AudioSegment.empty()
    while len(looped_music) < total_length_ms:
        looped_music += bg_music
    looped_music = looped_music[:total_length_ms]
    final_mix = looped_music.overlay(spoken, position=2000)
    return final_mix
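
# Example usage (sketch; assumes the TTS API keys above are set, a Streamlit
# session is active, and bg_music.mp3 exists in the working directory):
#
#     mp3_path = generate_audio_mp3("Welcome to the show!", speaker="Jane")
#     spoken = AudioSegment.from_file(mp3_path, format="mp3")
#     mixed = mix_with_bg_music(spoken)
#     mixed.export("episode_segment.mp3", format="mp3")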

###############################################################################
# Q&A UTILITY (POST-PODCAST)
###############################################################################

def call_groq_api_for_qa(system_prompt: str) -> str:
    """
    Single-step Q&A for post-podcast use. Prompts are usually short,
    so token usage is minimal.
    """
    try:
        headers = {
            "Authorization": f"Bearer {os.environ.get('GROQ_API_KEY')}",
            "Content-Type": "application/json",
            "Accept": "application/json"
        }
        data = {
            "model": "deepseek-r1-distill-llama-70b",
            "messages": [{"role": "user", "content": system_prompt}],
            "max_tokens": 512,
            "temperature": 0.7
        }
        r = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, data=json.dumps(data))
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"].strip()
    except Exception as e:
        print("[ERROR] Groq QA error:", e)
        # Fall back to a canned JSON answer so the caller always gets parseable output.
        fallback = {"speaker": "John", "text": "Sorry, I'm having trouble answering now."}
        return json.dumps(fallback)
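
# Example usage (sketch; assumes GROQ_API_KEY is set):
#
#     answer = call_groq_api_for_qa("Answer as John: What was the episode about?")
#     print(answer)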

###############################################################################
# ITERATIVE MERGING HELPER FUNCTION (BATCH PROCESSING STRATEGY)
###############################################################################

def iterative_merge_summaries(summaries: List[str], groq_client, references_text: str) -> str:
    """
    Iteratively merges a list of summaries into one final report summary.
    Groups summaries into batches whose total token count stays below a set
    threshold, merges each batch, then repeats on the batch outputs until
    only one summary remains.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")
    max_merge_input_tokens = 2000  # Safe input threshold for each merge call.
    round_index = 1
    while len(summaries) > 1:
        print(f"[LOG] Iterative merging round {round_index}: {len(summaries)} summaries to merge.")
        new_summaries = []
        i = 0
        while i < len(summaries):
            batch = []
            batch_tokens = 0
            # Group summaries until adding another would exceed the threshold.
            while i < len(summaries):
                summary = summaries[i]
                summary_tokens = len(tokenizer.encode(summary))
                if batch_tokens + summary_tokens <= max_merge_input_tokens or not batch:
                    batch.append(summary)
                    batch_tokens += summary_tokens
                    i += 1
                else:
                    break
            batch_text = "\n\n".join(batch)
            merge_prompt = f"""
You are a specialized summarization engine. Merge the following summaries into one comprehensive summary.
Summaries:
{batch_text}
References (if any):
{references_text}
Please output the merged summary.
"""
            data = {
                "model": MODEL_COMBINATION,
                "messages": [{"role": "user", "content": merge_prompt}],
                "temperature": 0.3,
                "max_tokens": 4096
            }
            merge_response = call_llm_with_retry(groq_client, **data)
            merged_batch = merge_response.choices[0].message.content.strip()
            # Strip any chain-of-thought the reasoning model may emit.
            merged_batch = re.sub(r"<think>.*?</think>", "", merged_batch, flags=re.DOTALL).strip()
            new_summaries.append(merged_batch)
        summaries = new_summaries
        round_index += 1
    return summaries[0]
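
# Example usage (sketch; assumes GROQ_API_KEY is set):
#
#     groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
#     final = iterative_merge_summaries(
#         ["Summary of source A...", "Summary of source B...", "Summary of source C..."],
#         groq_client,
#         references_text="- https://example.com/a\n- https://example.com/b",
#     )
#     print(final)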

###############################################################################
# LOW-CALL RESEARCH AGENT (Minimizing LLM Calls)
###############################################################################

MODEL_SUMMARIZATION = "llama-3.1-8b-instant"
MODEL_COMBINATION = "deepseek-r1-distill-llama-70b"


def run_research_agent(
    topic: str,
    report_type: str = "research_report",
    max_results: int = 20
) -> str:
    """
    Low-call approach:
      1) Tavily search (up to 'max_results' URLs).
      2) Firecrawl scrape => combined text.
      3) Use the full combined text without truncation.
      4) Split into chunks (up to 4500 tokens each, at most 10 chunks) and
         summarize each chunk individually.
      5) Iteratively merge the chunk summaries into a final report.
    => A small, bounded number of LLM calls (at most 10 chunk summaries plus
    a few merge calls) to reduce the chance of rate-limit errors.
    """
    print(f"[LOG] Starting LOW-CALL research agent for topic: {topic}")
    try:
        # Step 1: Tavily search
        print(f"[LOG] Step 1: Searching with Tavily for relevant URLs (max_results={max_results}).")
        tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
        search_data = tavily_client.search(query=topic, max_results=max_results)
        search_results = search_data.get("results", [])
        print(f"[LOG] Tavily provided {len(search_results)} results. Proceeding to Step 2.")
        if not search_results:
            print("[LOG] No relevant search results found by Tavily.")
            return "No relevant search results found."
        references_list = [r["url"] for r in search_results if "url" in r]

        # Step 2: Firecrawl scraping
        print("[LOG] Step 2: Scraping each URL with Firecrawl.")
        combined_content = ""
        for result in search_results:
            url = result["url"]
            print(f"[LOG] Firecrawl scraping: {url}")
            headers = {'Authorization': f'Bearer {os.environ.get("FIRECRAWL_API_KEY")}'}
            payload = {"url": url, "formats": ["markdown"], "onlyMainContent": True}
            try:
                resp = requests.post("https://api.firecrawl.dev/v1/scrape", headers=headers, json=payload)
                resp.raise_for_status()
                data = resp.json()
                if data.get("success") and "markdown" in data.get("data", {}):
                    combined_content += data["data"]["markdown"] + "\n\n"
                else:
                    print(f"[WARNING] Firecrawl scrape failed or no markdown for {url}: {data.get('error')}")
            except requests.RequestException as e:
                print(f"[ERROR] Firecrawl error for {url}: {e}")
                continue
        if not combined_content:
            print("[LOG] Could not retrieve content from any search results. Exiting.")
            return "Could not retrieve content from any of the search results."

        # Step 2.5: Input sanitization - remove any chain-of-thought markers.
        combined_content = re.sub(r"<think>.*?</think>", "", combined_content, flags=re.DOTALL)

        # Step 3: Use the full combined text without truncation.
        tokenizer = tiktoken.get_encoding("cl100k_base")
        total_tokens = len(tokenizer.encode(combined_content))
        print(f"[LOG] Step 3: Using the full combined text without truncation. Total tokens: {total_tokens}")

        # Step 4: Split the text into chunks (up to 4500 tokens each) and summarize each chunk.
        tokens = tokenizer.encode(combined_content)
        chunk_size = 4500  # Each chunk is 4500 tokens or less.
        total_chunks = math.ceil(len(tokens) / chunk_size)
        print(f"[LOG] Step 4: Splitting text into chunks of up to 4500 tokens. Total chunks: {total_chunks}")
        max_chunks = 10  # Summarize at most 10 chunks; any remainder is dropped.
        summaries = []
        start = 0
        chunk_index = 1
        groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        while start < len(tokens) and chunk_index <= max_chunks:
            end = min(start + chunk_size, len(tokens))
            chunk_text = tokenizer.decode(tokens[start:end])
            print(f"[LOG] Summarizing chunk {chunk_index} with ~{end - start} tokens.")
            prompt = f"""
You are a specialized summarization engine. Summarize the following text
for a professional research report. Provide accurate details but do not
include chain-of-thought or internal reasoning. Keep it concise, but
include key data points and context:
{chunk_text}
"""
            data = {
                "model": MODEL_SUMMARIZATION,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.2,
                "max_tokens": 768
            }
            response = call_llm_with_retry(groq_client, **data)
            summary_text = response.choices[0].message.content.strip()
            summaries.append(summary_text)
            start = end
            chunk_index += 1

        # Step 5: Iteratively merge the chunk summaries.
        print("[LOG] Step 5: Iteratively merging chunk summaries.")
        references_text = "\n".join(f"- {url}" for url in references_list) if references_list else "None"
        final_text = iterative_merge_summaries(summaries, groq_client, references_text)

        # Post-processing: remove any lingering chain-of-thought markers.
        final_text = re.sub(r"<think>.*?</think>", "", final_text, flags=re.DOTALL).strip()

        # Step 6: PDF generation
        print("[LOG] Step 6: Generating final PDF from the merged text.")
        final_report = generate_report(final_text)
        print("[LOG] Done! Returning PDF from run_research_agent (low-call).")
        return final_report
    except Exception as e:
        print(f"[ERROR] Error in run_research_agent: {e}")
        return f"Sorry, encountered an error: {str(e)}"
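
# Example usage (sketch; assumes TAVILY_API_KEY, FIRECRAWL_API_KEY, and
# GROQ_API_KEY are set, and that report_structure.generate_report returns
# the final PDF artifact as used elsewhere in this app):
#
#     report = run_research_agent("impact of solar energy on rural electrification")
#     print(type(report))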