|
|
import os |
|
|
import gradio as gr |
|
|
import chromadb |
|
|
|
|
|
from dotenv import load_dotenv |
|
|
import requests |
|
|
from google.cloud import speech |
|
|
|
|
|
from langchain_community.vectorstores import Chroma |
|
|
from langchain_community.embeddings import SentenceTransformerEmbeddings |
|
|
from langchain_core.prompts import PromptTemplate |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langchain_core.runnables import RunnablePassthrough |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
|
|
|
|
|
|
def build_brain_if_needed():
    """Build the persistent ChromaDB vector store if it does not exist yet.

    Loads ``knowledge.txt``, splits it into overlapping chunks, embeds the
    chunks with a SentenceTransformer model, and persists them to
    ``/tmp/chroma_db`` under the ``churchill_collection`` collection.
    Intended to run once on the server's first startup; later calls see the
    existing directory and return immediately.
    """
    db_path = "/tmp/chroma_db"
    if os.path.exists(db_path):
        # Already built on a previous startup — nothing to do.
        print("Database already exists. Skipping build.")
        return

    print("Database not found. Building now... (This will run only once on the server's first startup)")
    # Imported lazily: these are only needed for the one-time build.
    from langchain_community.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    loader = TextLoader('knowledge.txt', encoding='utf-8')
    documents = loader.load()
    # 1200-char chunks with 100-char overlap keep passages coherent while
    # staying small enough for retrieval.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=100)
    docs = text_splitter.split_documents(documents)

    # Embedding model must match the one used at query time in
    # load_and_build_chain(), otherwise retrieval results are meaningless.
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    persistent_client = chromadb.PersistentClient(path=db_path)

    # from_documents persists through the client; the returned handle was
    # previously bound to an unused local, so it is discarded here.
    Chroma.from_documents(
        client=persistent_client,
        documents=docs,
        embedding=embedding_function,
        collection_name="churchill_collection"
    )
    print("Database built successfully.")
|
|
|
|
|
|
|
|
# Ensure the vector database exists before the RAG chain is constructed below.
build_brain_if_needed()

# Load configuration from a local .env file; a no-op when the variables are
# already present in the environment (e.g. on a hosted Space).
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")

# Google Cloud credentials: prefer a JSON blob stored in the GCP_CREDENTIALS
# secret (written to a temp file so the client library can read it from disk);
# otherwise fall back to a local service-account key file.
gcp_credentials_json = os.getenv("GCP_CREDENTIALS")

gcp_credentials_path = "/tmp/gcp_credentials.json"

if gcp_credentials_json:
    with open(gcp_credentials_path, "w") as f:
        f.write(gcp_credentials_json)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = gcp_credentials_path
else:
    print("GCP_CREDENTIALS secret not found, falling back to local file.")
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "rare-palace-465414-s2-987829a9084e.json"

# Speech-to-text client; authenticates via GOOGLE_APPLICATION_CREDENTIALS.
speech_client = speech.SpeechClient()

# The ElevenLabs voice used for the spoken replies is mandatory configuration;
# fail fast at startup rather than on the first user turn.
ELEVENLABS_VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID")
if not ELEVENLABS_VOICE_ID:
    raise ValueError("ELEVENLABS_VOICE_ID not found in Space variables or .env file.")
|
|
|
|
|
|
|
|
def load_and_build_chain():
    """Open the persisted Chroma store and assemble the Churchill RAG chain.

    Returns a runnable mapping a question string to an in-persona answer
    string: retriever -> persona prompt -> Gemini 1.5 Flash -> plain text.
    Assumes build_brain_if_needed() has already populated /tmp/chroma_db.
    """
    # Must be the same embedding model used when the database was built.
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    persistent_client = chromadb.PersistentClient(path="/tmp/chroma_db")

    vectorstore = Chroma(
        client=persistent_client,
        embedding_function=embedding_function,
        collection_name="churchill_collection"
    )

    retriever = vectorstore.as_retriever()

    # Persona + grounding prompt. {context} and {question} are filled in by
    # the chain below; the template text itself must not be altered lightly.
    persona_prompt_template = """
You are Winston Churchill, the former Prime Minister of the United Kingdom, speaking from a historical moment in time around the mid-20th century. You do not possess any knowledge of events or inventions that occurred beyond your lifetime. If asked about unknown or unfamiliar topics (like "COVID-19" or "Instagram"), respond with thoughtful curiosity or your characteristic wit, rather than acknowledging any limitations in your knowledge.
Always speak in your distinctive voice: articulate, commanding, with a tone of wisdom, occasional humor, and historical depth.
Avoid stating you are an AI or referring to modern technologies unless they are explained to you by the user.
Do not reveal that you are unaware of things because you are from the past; instead, respond as someone encountering novel ideas for the first time.
*Core Instructions:*
1. *Greeting:* If the user offers a simple greeting like "Hello" or "How are you," provide a brief, polite, and in-character response. For example: "Good day to you," or "I am well, thank you. To what do I owe the pleasure?"
2. *Strictly Grounded:* Your knowledge is STRICTLY limited to the information provided in the "Context" below. You must not use any external knowledge.
3. *Knowledge Cutoff:* Your life and knowledge ended in 1955. You are utterly unaware of any event, person, or technology that appeared after this date. If a user mentions something from the future, you MUST express polite confusion and state that you have no knowledge of such matters. Do not attempt to guess.
4. *Persona:* Your tone is formal, eloquent, and resolute. Use powerful, definitive language.
5. *Conciseness:* Keep it short unless discussing complex historical topics that are covered in the context.
If the user mentions something outside your context (e.g., space travel, AI, COVID), you may inquire about it or respond with phrases like:
"My word, I have not heard of such a thing."
"That is quite unfamiliar to me—could you elaborate?"
"You speak of matters beyond my time. I am intrigued, albeit somewhat perplexed."
Ground your answers in your known historical context: the World Wars, British politics, speeches, diplomacy, and leadership, using the specific {context} provided.
Embrace your persona fully—respond with gravitas, insight, and the rhetorical flair for which you were known.
Context: {context}
Question: {question}
Answer as Winston Churchill:
"""

    prompt = PromptTemplate.from_template(persona_prompt_template)
    # Moderate temperature: some rhetorical flair while staying grounded.
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GOOGLE_API_KEY, temperature=0.7)

    # LCEL pipeline: retrieve context, render the prompt, call the LLM,
    # then strip the response down to a plain string.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
|
|
|
|
|
# Build the RAG chain once at startup; reused for every user turn.
qa_chain = load_and_build_chain()
|
|
|
|
|
|
|
|
def transcribe_speech(audio_filepath):
    """Transcribe a recorded audio file to text via Google Cloud Speech.

    Returns the first transcript alternative, a fallback message when the
    recogniser produced no results, an error message on any failure, or an
    empty string when no file path was given.
    """
    if not audio_filepath:
        return ""
    try:
        with open(audio_filepath, "rb") as recording:
            raw_bytes = recording.read()
        recognition_config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            language_code="en-GB",
        )
        recognition_audio = speech.RecognitionAudio(content=raw_bytes)
        response = speech_client.recognize(config=recognition_config, audio=recognition_audio)
        if not response.results:
            return "Could not understand the audio."
        return response.results[0].alternatives[0].transcript
    except Exception as e:
        print(f"Google STT Error: {e}")
        return "Error processing audio."
|
|
|
|
|
|
|
|
def generate_speech(text):
    """Synthesize `text` to /tmp/output.mp3 using the ElevenLabs TTS API.

    Returns the output file path on success, or None on any failure
    (non-200 HTTP response, timeout, or network/file error).
    """
    output_path = "/tmp/output.mp3"
    try:
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{ELEVENLABS_VOICE_ID}?output_format=mp3_44100_128"
        headers = {
            "xi-api-key": ELEVENLABS_API_KEY,
            "Content-Type": "application/json"
        }
        payload = {
            "text": text,
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75,
                "style": 0.0,
                "use_speaker_boost": True
            }
        }
        # BUGFIX: requests.post has no default timeout, so a stalled API call
        # could hang this worker forever. A timeout raises requests.Timeout,
        # which the except below converts to the usual None failure result.
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        if response.status_code == 200:
            with open(output_path, "wb") as f:
                f.write(response.content)
            return output_path
        print("ElevenLabs HTTP Error:", response.text)
        return None
    except Exception as e:
        print(f"TTS Error: {e}")
        return None
|
|
|
|
|
|
|
|
def process_user_turn(user_input, chat_history):
    """Run one conversational turn: answer `user_input` and voice the reply.

    Appends the user message and the assistant's reply to `chat_history`
    (mutated in place) and returns (chat_history, audio_filepath_or_None).
    Blank or missing input is ignored. On any failure the turn is recorded
    with an apologetic assistant message and no audio.
    """
    if not user_input or not user_input.strip():
        return chat_history, None
    try:
        reply = qa_chain.invoke(user_input)
        chat_history.append({"role": "user", "content": user_input})
        chat_history.append({"role": "assistant", "content": reply})
        return chat_history, generate_speech(reply)
    except Exception as exc:
        print(f"Processing Error: {exc}")
        chat_history.append({"role": "user", "content": user_input})
        chat_history.append({"role": "assistant", "content": "I'm terribly sorry, something went wrong."})
        return chat_history, None
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: single-column chat with both text and microphone input, plus an
# autoplaying audio component for the synthesized voice reply. The CSS string
# is passed verbatim to Gradio.
# ---------------------------------------------------------------------------
with gr.Blocks(css="""
#chatbox-container { max-width: 600px; margin: auto; box-shadow: 0 4px 12px rgba(0,0,0,0.1); border-radius: 15px; overflow: hidden; }
.gradio-container { background-color: #f4f4f9; padding-top: 2rem; }
.gr-button-primary { background: #3f51b5; color: white; border-radius: 10px; }
#chatbot { height: 450px; overflow-y: auto; border-radius: 10px; }
.gr-textbox textarea { border-radius: 10px; }
""", title="Conversational Time Machine") as demo:
    with gr.Column(elem_id="chatbox-container"):
        gr.Markdown("""# 🕰️ Winston Churchill AI Chat
Type or record your message to talk to Sir Winston Churchill.
""")

        # type='messages' => history is a list of {"role", "content"} dicts,
        # matching what process_user_turn appends.
        chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot", height=450, type='messages')
        audio_out = gr.Audio(label="Churchill's Voice", autoplay=True, interactive=False)

        with gr.Row():
            text_in = gr.Textbox(placeholder="Type a message...", scale=7)
            send_btn = gr.Button("➤", variant="primary", scale=1)

        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Record your question")

        def handle_text_submission(message, history):
            """Handle a typed message; third output clears the textbox."""
            history, audio = process_user_turn(message, history)
            return history, audio, ""

        def handle_audio_submission(audio_file, history):
            """Handle a finished recording: transcribe it, then answer."""
            if not audio_file:
                return history, None, ""
            transcribed = transcribe_speech(audio_file)
            history, audio = process_user_turn(transcribed, history)
            return history, audio, ""

        # Event wiring: Enter key, send button, and end-of-recording all run
        # a turn; each also resets the text input via the third output.
        text_in.submit(handle_text_submission, [text_in, chatbot], [chatbot, audio_out, text_in])
        send_btn.click(handle_text_submission, [text_in, chatbot], [chatbot, audio_out, text_in])
        audio_in.stop_recording(handle_audio_submission, [audio_in, chatbot], [chatbot, audio_out, text_in])

# Bind to all interfaces so the app is reachable from inside a container/Space.
demo.launch(server_name="0.0.0.0")
|
|
|