Spaces:
Running
Running
| import os | |
| import io | |
| import json | |
| import base64 | |
| import requests | |
| import gradio as gr | |
| from bs4 import BeautifulSoup | |
| from groq import Groq | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# --- API KEYS ---
# All keys come from the environment (.env via load_dotenv above). Only
# GROQ_API_KEY is effectively required: the Groq client is constructed
# unconditionally below, so a missing key surfaces on the first API call.
BRIGHTDATA_API_KEY = os.getenv("BRIGHTDATA_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
HF_API_KEY = os.getenv("HF_API_KEY")  # Optional: Hugging Face token for image gen

# --- Clients ---
client = Groq(api_key=GROQ_API_KEY)
openai_client = None
if OPENAI_API_KEY:
    # Imported lazily so the app still starts when the openai package or key
    # is absent; openai_client stays None and callers must check it.
    from openai import OpenAI
    openai_client = OpenAI(api_key=OPENAI_API_KEY)

# --- Persistent Storage ---
# Chat history is a flat list of {"role", "content"} dicts, rewritten to disk
# after every LLM exchange (see ask_llm).
HISTORY_FILE = "chat_history.json"
conversation_history = []
if os.path.exists(HISTORY_FILE):
    try:
        with open(HISTORY_FILE, "r") as f:
            loaded = json.load(f)
        # Guard against a hand-edited file that parses but is not a list.
        if isinstance(loaded, list):
            conversation_history = loaded
    except (ValueError, OSError):
        # Fix: the original caught `(json.JSONDecodeError, Exception)`, where
        # the broad Exception made the first entry redundant. Narrowed to the
        # failure modes that actually occur here: ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError; OSError covers I/O.
        conversation_history = []
# ----------------------
# LLM Wrapper
# ----------------------
def ask_llm(query, context=None):
    """
    Send *query* (plus optional context and the stored history) to Groq.

    Side effects: appends the user/assistant exchange to the module-level
    conversation_history and rewrites HISTORY_FILE. Returns the model's
    answer, or an error string if anything in the exchange fails.
    """
    system_prompt = """
    You are a helpful AI assistant.
    Use ONLY the provided context and conversation history to answer the question.
    If the answer is not found in the context, respond clearly that you don't know based on the provided info.
    """
    prompt_stack = [{"role": "system", "content": system_prompt}]
    if context:
        prompt_stack.append({"role": "system", "content": f"CONTEXT:\n{context}"})
    prompt_stack += conversation_history
    prompt_stack.append({"role": "user", "content": query})
    try:
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=prompt_stack,
            temperature=0.3,
        )
        reply = completion.choices[0].message.content
        conversation_history.extend([
            {"role": "user", "content": query},
            {"role": "assistant", "content": reply},
        ])
        # Persist after every turn so history survives restarts.
        with open(HISTORY_FILE, "w") as fh:
            json.dump(conversation_history, fh, indent=2)
        return reply
    except Exception as e:
        return f"Error communicating with LLM: {str(e)}"
# ----------------------
# Image Generation (V5)
# ----------------------
def should_generate_image(query: str, answer: str) -> bool:
    """
    Ask a tiny, zero-temperature LLM call whether a visual would add value.

    Returns True only when the model's reply starts with "yes"; any API
    failure is treated as "no image needed".
    """
    check_prompt = f"""
    User asked: "{query}"
    AI answered: "{answer[:500]}"
    Should an image be generated to visually illustrate this response?
    Answer ONLY with "yes" or "no". Consider "yes" for requests involving:
    - Visual explanations of concepts (diagrams, architecture, structure)
    - Descriptions of places, objects, scenes, or people
    - "Show me", "illustrate", "visualize", "draw", "what does X look like"
    - Scientific, historical, or technical topics that benefit from visuals
    - Creative or imaginative requests
    Answer "no" for purely factual, conversational, or data-only responses.
    """
    try:
        reply = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": check_prompt}],
            temperature=0.0,
            max_tokens=5,
        )
        return reply.choices[0].message.content.strip().lower().startswith("yes")
    except Exception:
        # Fail closed: skip the image rather than break the main answer.
        return False
def build_image_prompt(query: str) -> str:
    """
    Distill the user's request into a short, vivid image-generation prompt.

    Uses a quick LLM call; falls back to returning the raw query unchanged
    if the call fails.
    """
    prompt_gen = f"""
    Convert this user request into a short, vivid image generation prompt (max 30 words).
    Focus on visual elements only. Be descriptive and specific.
    User request: "{query}"
    Image prompt:"""
    try:
        reply = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": prompt_gen}],
            temperature=0.5,
            max_tokens=60,
        )
        # Strip stray wrapping quotes the model sometimes emits.
        text = reply.choices[0].message.content
        return text.strip().strip('"').strip("'")
    except Exception:
        return query  # fallback to raw query
def generate_image_openai(image_prompt: str):
    """
    Generate an image with DALL-E 3 via OpenAI.

    Returns (image_bytes, None) on success or (None, error_message) on any
    failure (missing client, API error, or download problem).
    """
    if openai_client is None:
        return None, "OpenAI key not configured."
    try:
        result = openai_client.images.generate(
            model="dall-e-3",
            prompt=image_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = result.data[0].url
        # Download with up to three attempts; only timeouts are retried,
        # every other download error aborts immediately.
        for attempt in range(3):
            try:
                download = requests.get(image_url, timeout=60)
                download.raise_for_status()
                return download.content, None
            except requests.exceptions.Timeout:
                if attempt == 2:
                    return None, "Image download timed out after 3 attempts. Try again."
            except Exception as e:
                return None, f"Image download error: {str(e)}"
    except Exception as e:
        return None, f"DALL-E 3 error: {str(e)}"
def generate_image_hf(image_prompt: str, model: str = "black-forest-labs/FLUX.1-schnell"):
    """
    Generate an image via the Hugging Face Inference API (free OSS models).

    Sends an Authorization header only when HF_API_KEY is configured.
    Returns (image_bytes, None) on HTTP 200, else (None, error_message).
    """
    endpoint = f"https://api-inference.huggingface.co/models/{model}"
    auth_headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
    try:
        resp = requests.post(
            endpoint,
            headers=auth_headers,
            json={"inputs": image_prompt},
            timeout=60,
        )
        if resp.status_code == 200:
            return resp.content, None
        return None, f"HF API error {resp.status_code}: {resp.text[:200]}"
    except Exception as e:
        return None, f"HF request error: {str(e)}"
def generate_image(image_prompt: str, use_openai: bool = False):
    """
    Dispatch image generation to the chosen backend.

    Uses DALL-E 3 when use_openai is set and a client exists; otherwise the
    default HF model, falling back to DALL-E if HF fails and an OpenAI
    client is available. Returns (PIL.Image or None, error_message or None).
    """
    if use_openai and openai_client:
        raw, err = generate_image_openai(image_prompt)
    else:
        raw, err = generate_image_hf(image_prompt)
        if raw is None and openai_client:
            # Fallback to DALL-E if HF fails
            raw, err = generate_image_openai(image_prompt)
    if not raw:
        return None, err
    from PIL import Image
    return Image.open(io.BytesIO(raw)), None
def maybe_generate_image(query: str, answer: str, image_backend: str):
    """
    Full pipeline: decide → build prompt → generate → return
    (image, prompt_used, status_message).

    Fix: the "Stable Diffusion XL" dropdown choice was previously ignored —
    every non-OpenAI selection silently used FLUX.1-schnell. The SDXL label
    now routes to the SDXL model on the HF Inference API (with the same
    DALL-E fallback the generic HF path gets).
    """
    if not should_generate_image(query, answer):
        return None, "", "No image needed for this response."
    image_prompt = build_image_prompt(query)
    if image_backend == "DALL-E 3 (OpenAI)":
        image, err = generate_image(image_prompt, use_openai=True)
    elif "Stable Diffusion" in image_backend:
        img_bytes, err = generate_image_hf(
            image_prompt, model="stabilityai/stable-diffusion-xl-base-1.0"
        )
        if img_bytes is None and openai_client:
            # Mirror generate_image's DALL-E fallback for a failed HF call.
            img_bytes, err = generate_image_openai(image_prompt)
        image = None
        if img_bytes:
            from PIL import Image
            image = Image.open(io.BytesIO(img_bytes))
    else:
        image, err = generate_image(image_prompt, use_openai=False)
    if image:
        return image, image_prompt, f"✅ Image generated using: {image_backend}"
    return None, image_prompt, f"⚠️ Image generation failed: {err}"
# ----------------------
# Website Scraper
# ----------------------
def scrape_website(url, question, image_backend):
    """
    Fetch *url* through Bright Data's web unlocker, answer *question* from
    the page text, and optionally attach a generated image.

    Returns (answer, image, image_prompt, image_status); error strings go in
    the answer slot with None/empty companions.
    """
    try:
        resp = requests.post(
            "https://api.brightdata.com/request",
            headers={"Authorization": f"Bearer {BRIGHTDATA_API_KEY}"},
            json={"zone": "web_unlocker1", "url": url, "format": "raw"},
            timeout=60,
        )
        if resp.status_code != 200:
            return f"Bright Data Error: {resp.status_code}", None, "", ""
        page_text = BeautifulSoup(resp.text, "html.parser").get_text(separator=" ", strip=True)
        if not page_text:
            return "⚠️ Could not extract content from the website.", None, "", ""
        # Cap context to keep the prompt within the model's window.
        answer = ask_llm(question, context=page_text[:12000])
        image, img_prompt, img_status = maybe_generate_image(question, answer, image_backend)
        return answer, image, img_prompt, img_status
    except Exception as e:
        return f"Error scraping website: {str(e)}", None, "", ""
# ----------------------
# YouTube Transcript Q&A
# ----------------------
def youtube_qa(video_id, question, image_backend):
    """
    Answer *question* from a YouTube video's transcript, optionally with an
    accompanying generated image.

    Returns (answer, image, image_prompt, image_status).
    """
    try:
        # Fix: only the transcript fetch/parse belongs in the try. Previously
        # the except also swallowed LLM and image-generation failures and
        # mislabeled them as transcript-retrieval errors.
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        full_text = " ".join(entry["text"] for entry in transcript)
    except Exception:
        return "❌ Could not retrieve transcript. Invalid video ID or no transcript available.", None, "", ""
    if not full_text.strip():
        return "⚠️ No transcript text found.", None, "", ""
    answer = ask_llm(question, context=full_text[:12000])
    image, img_prompt, img_status = maybe_generate_image(question, answer, image_backend)
    return answer, image, img_prompt, img_status
# ----------------------
# Voice Chat (STT + TTS)
# ----------------------
def voice_chat(audio_file, image_backend):
    """
    Full voice round-trip: transcribe with Whisper (Groq), answer via LLM,
    synthesize speech with OpenAI TTS when a client is available, and
    optionally attach a generated image.

    Returns (user_text, answer_text, audio_path, image, img_prompt, img_status).
    """
    if not audio_file:
        return "", "⚠️ No audio provided.", None, None, "", ""
    # --- Speech-to-text ---
    try:
        with open(audio_file, "rb") as audio:
            stt = client.audio.transcriptions.create(
                file=audio,
                model="whisper-large-v3",
            )
        user_text = stt.text
    except Exception as e:
        return "", f"❌ Could not transcribe audio: {e}", None, None, "", ""
    answer_text = ask_llm(user_text)
    # --- Text-to-speech (best-effort; requires OpenAI client) ---
    audio_path = None
    try:
        if openai_client:
            speech = openai_client.audio.speech.create(
                model="tts-1",
                voice="alloy",
                input=answer_text[:4096],
            )
            os.makedirs("temp_audio", exist_ok=True)
            audio_path = "temp_audio/output.mp3"
            with open(audio_path, "wb") as out:
                out.write(speech.content)
    except Exception as e:
        # TTS failure is non-fatal: report it inline and keep the text answer.
        answer_text += f"\n\n❌ Could not generate audio: {e}"
    image, img_prompt, img_status = maybe_generate_image(user_text, answer_text, image_backend)
    return user_text, answer_text, audio_path, image, img_prompt, img_status
# ----------------------
# Text Chat (New Tab - V5)
# ----------------------
def text_chat(question, image_backend):
    """
    Answer a plain text question and optionally generate an image for it.

    Returns (answer, image, image_prompt, image_status).
    """
    if not question.strip():
        return "⚠️ Please enter a question.", None, "", ""
    answer = ask_llm(question)
    image, img_prompt, img_status = maybe_generate_image(question, answer, image_backend)
    return answer, image, img_prompt, img_status
# ----------------------
# Gradio Interface
# ----------------------
# Backend labels shown in every tab's dropdown. These exact strings are
# matched by the image pipeline (the DALL-E label is compared literally in
# maybe_generate_image), so edits here must stay in sync with that logic.
IMAGE_BACKENDS = [
    "FLUX.1-schnell (HuggingFace - Free)",
    "Stable Diffusion XL (HuggingFace - Free)",
    "DALL-E 3 (OpenAI)",
]
# Declarative UI: four tabs, each following the same pattern — inputs on the
# left column, answer/image/status outputs on the right, and one button that
# wires a handler above to (inputs, outputs). Statement order here defines
# the on-screen layout, so components are created in display order.
with gr.Blocks(title="Multimodal AI Assistant v5") as demo:
    gr.Markdown("# 🤖 Multimodal AI Assistant — Version 5")
    gr.Markdown("Now with **on-demand image generation** 🎨 powered by FLUX.1, SDXL, or DALL-E 3.")
    with gr.Tabs():
        # Tab 0: Text Chat + Image
        with gr.Tab("💬 Text Chat"):
            with gr.Row():
                with gr.Column(scale=2):
                    tc_question = gr.Textbox(label="Your Question", placeholder="Ask anything… e.g. 'Explain the solar system visually'")
                    tc_backend = gr.Dropdown(IMAGE_BACKENDS, value=IMAGE_BACKENDS[0], label="Image Generation Backend")
                    tc_btn = gr.Button("Ask", variant="primary")
                with gr.Column(scale=3):
                    # Outputs mirror text_chat's 4-tuple return.
                    tc_answer = gr.Textbox(label="Answer", lines=6)
                    tc_image = gr.Image(label="Generated Image (if applicable)", type="pil")
                    tc_img_prompt = gr.Textbox(label="Image Prompt Used", interactive=False)
                    tc_img_status = gr.Textbox(label="Image Status", interactive=False)
            tc_btn.click(
                text_chat,
                inputs=[tc_question, tc_backend],
                outputs=[tc_answer, tc_image, tc_img_prompt, tc_img_status]
            )
        # Tab 1: Website Q&A
        with gr.Tab("🌐 Website Q&A"):
            with gr.Row():
                with gr.Column(scale=2):
                    url_input = gr.Textbox(label="Website URL")
                    website_question = gr.Textbox(label="Ask a Question")
                    ws_backend = gr.Dropdown(IMAGE_BACKENDS, value=IMAGE_BACKENDS[0], label="Image Backend")
                    website_btn = gr.Button("Ask", variant="primary")
                with gr.Column(scale=3):
                    website_output = gr.Textbox(label="Answer", lines=6)
                    ws_image = gr.Image(label="Generated Image", type="pil")
                    ws_img_prompt = gr.Textbox(label="Image Prompt Used", interactive=False)
                    ws_img_status = gr.Textbox(label="Image Status", interactive=False)
            website_btn.click(
                scrape_website,
                inputs=[url_input, website_question, ws_backend],
                outputs=[website_output, ws_image, ws_img_prompt, ws_img_status]
            )
        # Tab 2: YouTube Transcript Q&A
        with gr.Tab("🎥 YouTube Q&A"):
            with gr.Row():
                with gr.Column(scale=2):
                    # Expects the bare video ID, not a full URL (see youtube_qa).
                    video_id_input = gr.Textbox(label="YouTube Video ID")
                    youtube_question = gr.Textbox(label="Ask a Question")
                    yt_backend = gr.Dropdown(IMAGE_BACKENDS, value=IMAGE_BACKENDS[0], label="Image Backend")
                    youtube_btn = gr.Button("Ask", variant="primary")
                with gr.Column(scale=3):
                    youtube_output = gr.Textbox(label="Answer", lines=6)
                    yt_image = gr.Image(label="Generated Image", type="pil")
                    yt_img_prompt = gr.Textbox(label="Image Prompt Used", interactive=False)
                    yt_img_status = gr.Textbox(label="Image Status", interactive=False)
            youtube_btn.click(
                youtube_qa,
                inputs=[video_id_input, youtube_question, yt_backend],
                outputs=[youtube_output, yt_image, yt_img_prompt, yt_img_status]
            )
        # Tab 3: Voice Chat
        with gr.Tab("🎤 Voice Chat"):
            with gr.Row():
                with gr.Column(scale=2):
                    # Microphone recording saved to a temp file path for voice_chat.
                    audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record your question")
                    vc_backend = gr.Dropdown(IMAGE_BACKENDS, value=IMAGE_BACKENDS[0], label="Image Backend")
                    voice_btn = gr.Button("Ask", variant="primary")
                with gr.Column(scale=3):
                    # Outputs mirror voice_chat's 6-tuple return.
                    voice_text_output = gr.Textbox(label="Transcribed Text")
                    voice_answer_output = gr.Textbox(label="AI Answer", lines=5)
                    voice_audio_output = gr.Audio(label="AI Voice Response", autoplay=True)
                    vc_image = gr.Image(label="Generated Image", type="pil")
                    vc_img_prompt = gr.Textbox(label="Image Prompt Used", interactive=False)
                    vc_img_status = gr.Textbox(label="Image Status", interactive=False)
            voice_btn.click(
                voice_chat,
                inputs=[audio_input, vc_backend],
                outputs=[voice_text_output, voice_answer_output, voice_audio_output, vc_image, vc_img_prompt, vc_img_status]
            )
demo.launch()