Spaces:
Sleeping
Sleeping
| ########################################## | |
| # Step 0: Essential imports | |
| ########################################## | |
| import streamlit as st # Web interface | |
| from transformers import ( # AI components: emotion analysis, text-to-speech, text generation | |
| pipeline, | |
| SpeechT5Processor, | |
| SpeechT5ForTextToSpeech, | |
| SpeechT5HifiGan, | |
| AutoModelForCausalLM, | |
| AutoTokenizer | |
| ) | |
| from datasets import load_dataset # To load speaker embeddings dataset | |
| import torch # For tensor operations | |
| import soundfile as sf # For audio file writing | |
| import sentencepiece # Required for SpeechT5Processor tokenization | |
| ########################################## | |
| # Initial configuration (MUST BE FIRST) | |
| ########################################## | |
# Page-level settings: browser-tab title, tab icon and a centered layout.
# Kept directly after the imports so it runs before any other Streamlit call.
st.set_page_config(page_title="Just Comment", page_icon="💬", layout="centered")
| ########################################## | |
| # Optimized model loader with caching | |
| ########################################## | |
@st.cache_resource(show_spinner=False)
def _load_components():
    """Load every model the app needs and cache the bundle across reruns.

    The ``st.cache_resource`` decorator keeps the returned objects alive
    between script reruns, so the (expensive) model loading below happens
    once per server process instead of on every user interaction.

    Returns:
        dict with the keys ``"emotion"`` (text-classification pipeline),
        ``"text_model"`` / ``"text_tokenizer"`` (Qwen causal LM),
        ``"tts_processor"`` / ``"tts_model"`` / ``"tts_vocoder"`` (SpeechT5),
        ``"speaker_emb"`` (1 x N speaker x-vector tensor) and ``"device"``
        (``"cuda"`` or ``"cpu"``).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision only pays off on GPU; on CPU it is slow and some
    # operators are not implemented for float16, so fall back to float32.
    lm_dtype = torch.float16 if use_cuda else torch.float32

    # Emotion classifier; truncation keeps over-long comments within the
    # model's maximum input length.
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
    )

    # Text generation model and tokenizer.
    text_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")
    if use_cuda:
        # device_map="auto" places the weights on the GPU itself.
        text_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B",
            torch_dtype=lm_dtype,
            device_map="auto",
        )
    else:
        text_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B",
            torch_dtype=lm_dtype,
        ).to(device)

    # Text-to-speech stack. Loaded in float32 on every device: the speaker
    # x-vector below is float32 (mixing it with float16 weights raises a
    # dtype mismatch) and the waveform is later written with soundfile,
    # which does not accept float16 arrays.
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=torch.float32,
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=torch.float32,
    ).to(device)

    # Pre-computed speaker embedding (entry 7306 of the x-vector dataset),
    # reshaped to a batch of one.
    xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_emb = torch.tensor(
        xvectors[7306]["xvector"], dtype=torch.float32
    ).unsqueeze(0).to(device)

    return {
        "emotion": emotion_pipe,
        "text_model": text_model,
        "text_tokenizer": text_tokenizer,
        "tts_processor": tts_processor,
        "tts_model": tts_model,
        "tts_vocoder": tts_vocoder,
        "speaker_emb": speaker_emb,
        "device": device,
    }
| ########################################## | |
| # User interface components | |
| ########################################## | |
def _show_interface():
    """Draw the page header and the comment box; return the text entered."""
    st.title("🚀 Just Comment")
    st.markdown("### I'm listening to you, my friend~")
    # The widget's current value is handed back to the caller unchanged.
    comment = st.text_area(
        "📝 Enter your comment:",
        placeholder="Share your thoughts...",
        height=150,
        key="input",
    )
    return comment
| ########################################## | |
| # Core processing functions | |
| ########################################## | |
| def _fast_emotion(text, analyzer): | |
| """Rapidly detect dominant emotion using a truncated input.""" | |
| result = analyzer(text[:256], return_all_scores=True)[0] # Analyze first 256 characters | |
| valid_emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'] | |
| return max( | |
| (e for e in result if e['label'].lower() in valid_emotions), | |
| key=lambda x: x['score'], | |
| default={'label': 'neutral', 'score': 0} | |
| ) | |
| def _build_prompt(text, emotion): | |
| """Build a continuous prompt (1–3 sentences) based on detected emotion.""" | |
| templates = { | |
| "sadness": "I sensed sadness in your comment: {text}. We are sorry and ready to support you.", | |
| "joy": "Your comment shows joy: {text}. Thank you for your positive feedback; we are excited to serve you better.", | |
| "love": "Your comment expresses love: {text}. We appreciate your heartfelt words and value our connection.", | |
| "anger": "I understand your comment reflects anger: {text}. Please accept our sincere apologies as we address your concerns.", | |
| "fear": "It seems you feel fear: {text}. Rest assured, your safety and satisfaction are our top priorities.", | |
| "surprise": "Your comment exudes surprise: {text}. We are pleased by your experience and will strive to exceed your expectations.", | |
| "neutral": "Thank you for your comment: {text}. We are committed to providing you with excellent service." | |
| } | |
| # Use the template corresponding to the detected emotion (default to neutral) | |
| return templates.get(emotion.lower(), templates["neutral"]).format(text=text[:200]) | |
def _generate_response(text, models):
    """Generate a short reply to ``text`` that is conditioned on its emotion.

    Args:
        text: The user's comment.
        models: The component dict; this function reads ``"emotion"``,
            ``"text_tokenizer"``, ``"text_model"`` and ``"device"``.

    Returns:
        The newly generated continuation (the prompt itself is not
        included), stripped and truncated to at most 200 characters.
    """
    # Detect the dominant emotion and build the matching prompt.
    detected_emotion = _fast_emotion(text, models["emotion"])
    prompt = _build_prompt(text, detected_emotion["label"])
    print(f"Generated prompt: {prompt}")  # Debug output

    tokenizer = models["text_tokenizer"]
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=100,
        truncation=True,
    ).to(models["device"])

    output = models["text_model"].generate(
        # Unpack the whole encoding so the attention mask is passed along
        # with input_ids; pad and EOS share an id here, so the mask cannot
        # be inferred reliably from the ids alone.
        **inputs,
        max_new_tokens=120,
        # min_new_tokens counts generated tokens only, whereas min_length
        # would also count the prompt and could be satisfied by it alone.
        min_new_tokens=50,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The output sequence starts with the prompt tokens; drop them so that
    # only the newly generated text is decoded.
    input_len = inputs.input_ids.shape[1]
    generated_text = tokenizer.decode(
        output[0][input_len:], skip_special_tokens=True
    )
    # If the model emitted a "Response:" marker, keep only what follows it.
    response = generated_text.split("Response:")[-1].strip()
    print(f"Generated response: {response}")  # Debug output
    return response[:200]  # Cap the reply at 200 characters
def _text_to_speech(text, models):
    """Synthesize ``text`` to speech and return the path of the WAV file.

    Args:
        text: Text to speak; only the first 150 characters are synthesized
            to keep latency low.
        models: The component dict; this function reads ``"tts_processor"``,
            ``"tts_model"``, ``"tts_vocoder"``, ``"speaker_emb"`` and
            ``"device"``.

    Returns:
        ``"output.wav"``, the file the audio was written to (16 kHz).
    """
    inputs = models["tts_processor"](
        text=text[:150],
        return_tensors="pt",
    ).to(models["device"])

    tts_model = models["tts_model"]
    # The speaker embedding must have the same dtype as the model weights,
    # otherwise the forward pass fails with a dtype mismatch.
    speaker_emb = models["speaker_emb"].to(dtype=tts_model.dtype)

    with torch.inference_mode():  # No autograd bookkeeping during inference
        spectrogram = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_emb,
        )
        audio = models["tts_vocoder"](spectrogram)

    # soundfile cannot write float16 arrays, so convert to float32 first.
    waveform = audio.float().cpu().numpy()
    sf.write("output.wav", waveform, 16000)
    return "output.wav"
| ########################################## | |
| # Main application flow | |
| ########################################## | |
def main():
    """Run the app: read a comment, show a generated reply, then speak it."""
    # Local import keeps this block self-contained; html is standard library.
    import html

    models = _load_components()  # Models and helper components
    user_input = _show_interface()  # Header plus comment text area

    # Nothing to do until the user has typed something other than whitespace.
    if not user_input or not user_input.strip():
        return

    with st.spinner("🔍 Generating response..."):
        generated_response = _generate_response(user_input, models)

    st.subheader("📄 Response")
    # The reply is model output (and may echo user text), and it is rendered
    # with unsafe_allow_html=True, so escape it to prevent HTML injection.
    safe_response = html.escape(generated_response)
    st.markdown(
        f"<p style='color:#3498DB; font-size:20px;'>{safe_response}</p>",
        unsafe_allow_html=True,
    )

    with st.spinner("🔊 Synthesizing audio..."):
        audio_file = _text_to_speech(generated_response, models)
    # Audio player positioned at the start of the clip.
    st.audio(audio_file, format="audio/wav", start_time=0)

    print(f"Final generated response: {generated_response}")  # Debug output


if __name__ == "__main__":
    main()