Spaces:
Sleeping
Sleeping
| import torch | |
| torch.set_float32_matmul_precision('high') | |
| from flask import Flask, send_from_directory, request, Response | |
| import os | |
| import base64 | |
| import numpy as np | |
| from inference import OmniInference | |
| import io | |
app = Flask(__name__)


def _load_omni():
    """Construct the shared OmniInference engine, reporting progress on stdout.

    Any exception raised while constructing the engine is printed and then
    re-raised unchanged, so importing this module fails loudly instead of
    leaving ``omni`` undefined.
    """
    try:
        print("Initializing OmniInference...")
        engine = OmniInference()
        print("OmniInference initialized successfully.")
    except Exception as exc:
        print(f"Error initializing OmniInference: {str(exc)}")
        raise
    return engine


# Built once at import time; the request handlers below use this module-level instance.
omni = _load_omni()
def serve_html():
    """Return the HTML demo page stored at ``webui/omni_html_demo.html``."""
    # NOTE(review): no @app.route decorator is visible above this view in this
    # copy of the file; confirm it is registered (presumably at "/").
    directory, page = '.', 'webui/omni_html_demo.html'
    return send_from_directory(directory, page)
def _decode_audio_payload(audio_data):
    """Decode a base64 audio string into raw bytes.

    Accepts either a bare base64 string or a data URL such as
    ``"data:audio/wav;base64,<payload>"``; everything after the first comma
    is treated as the base64 payload.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``,
            raised by ``base64.b64decode``, is a subclass of ``ValueError``).
    """
    # split(',', 1)[-1] yields the whole string when there is no comma.
    encoded = audio_data.split(',', 1)[-1]
    return base64.b64decode(encoded)


def chat():
    """Run OmniInference on a base64-encoded audio upload and return audio.

    Expects a JSON body of the form ``{"audio": "<base64 or data URL>"}``.
    The decoded bytes are written to a private temporary ``.wav`` file whose
    path is passed to ``omni.run_AT_batch_stream``; every chunk it yields is
    concatenated into a single ``audio/wav`` response.

    Returns:
        A ``Response`` containing the generated audio on success;
        a ``(message, 400)`` tuple when the request carries no usable audio
        (missing/invalid JSON, missing or empty ``audio``, non-string value,
        or malformed base64); a ``(message, 500)`` tuple when inference or
        any other step fails unexpectedly.
    """
    # NOTE(review): no @app.route decorator is visible above this view in this
    # copy of the file; confirm it is registered as a POST endpoint.
    import tempfile  # function-scope import keeps this edit self-contained

    try:
        # silent=True turns a missing or unparsable JSON body into None instead
        # of raising, so the client receives a 400 rather than a generic 500.
        payload = request.get_json(silent=True)
        audio_data = payload.get('audio') if isinstance(payload, dict) else None
        if not audio_data:
            return "No audio data received", 400
        if not isinstance(audio_data, str):
            return "Audio data must be a base64-encoded string", 400

        try:
            audio_bytes = _decode_audio_payload(audio_data)
        except ValueError:
            return "Invalid base64 audio data", 400

        # One unique temp file per request: a fixed filename would let
        # concurrent requests overwrite each other's input audio.
        fd, temp_audio_path = tempfile.mkstemp(suffix='.wav')
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(audio_bytes)

            # Generate response using OmniInference
            try:
                # b''.join avoids the quadratic cost of repeated bytes +=.
                all_audio = b''.join(omni.run_AT_batch_stream(temp_audio_path))
            except Exception as inner_e:
                print(f"Error in OmniInference processing: {str(inner_e)}")
                return f"An error occurred during audio processing: {str(inner_e)}", 500
            return Response(all_audio, mimetype='audio/wav')
        finally:
            # Best-effort cleanup on every path (success, inference error,
            # or a failure while writing the file).
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        return f"An error occurred: {str(e)}", 500
if __name__ == '__main__':
    # Run Flask's built-in server, listening on every interface at port 7860.
    # NOTE(review): 7860 is presumably the port the hosting Space expects — confirm before changing.
    bind_host, bind_port = '0.0.0.0', 7860
    app.run(host=bind_host, port=bind_port)