# Captured page header ("Spaces:" / "Paused" / "Paused"); not part of the program.
import inspect
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import date, datetime
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, TypedDict
from urllib import parse
from uuid import uuid4

import requests
import uvicorn
from dotenv import dotenv_values
from fastapi import APIRouter, Body, FastAPI, Request, status
from fastapi.logger import logger as fastapi_logger
from fastapi.middleware.cors import CORSMiddleware
from pymongo import MongoClient

from routes import router as api_router

# NOTE: the star imports below can shadow any name imported above, and the
# imports after them can shadow names they export; keep this relative order.
from mongodb.operations.calls import *
from mongodb.operations.users import *
from mongodb.models.calls import UserCall, UpdateCall
# from mongodb.endpoints.calls import *
from transformers import AutoProcessor, SeamlessM4Tv2Model
# from seamless_communication.inference import Translator
from Client import Client
import numpy as np
import torch
import socketio
# Logging: route uvicorn and FastAPI log output through gunicorn's handlers.
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
uvicorn_access_logger = logging.getLogger("uvicorn.access")

# All three loggers pass their records up to their parent loggers.
gunicorn_error_logger.propagate = gunicorn_logger.propagate = uvicorn_access_logger.propagate = True

# The *same* handler list object is assigned (not a copy), so uvicorn's access
# logger and FastAPI's logger share gunicorn.error's handlers; a handler added
# to any one of these loggers is seen by all of them.
uvicorn_access_logger.handlers = fastapi_logger.handlers = gunicorn_error_logger.handlers
# Main Socket.IO entry point: an asyncio server in ASGI mode that accepts any
# origin and writes both its own and Engine.IO's logs to the gunicorn logger.
sio = socketio.AsyncServer(
    cors_allowed_origins="*",
    async_mode="asgi",
    engineio_logger=gunicorn_logger,
    logger=gunicorn_logger,
)
# For verbose Socket.IO logs: sio.logger.setLevel(logging.DEBUG)

# ASGI application wrapping `sio`, suitable for mounting on the FastAPI app.
socketio_app = socketio.ASGIApp(sio)
# Settings read from a local .env file (an empty mapping if the file is absent).
config = dotenv_values(".env")

# MongoDB connection string: prefer the .env value; otherwise fall back to the
# process environment (e.g. secrets injected into a container). A KeyError for
# 'MONGODB_URI' is still raised when neither source defines it.
uri = config.get("MONGODB_URI") or os.environ["MONGODB_URI"]
# MongoDB connection lifespan events.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the MongoDB client for the lifetime of the FastAPI app.

    Startup attaches a ``MongoClient`` to ``app.mongodb_client`` and the
    ``IT-Cluster1`` database to ``app.database``, then pings the server. A
    failed ping is printed but does not abort startup. The client is closed
    on shutdown, including when the app exits because of an exception.

    Decorated with ``asynccontextmanager`` because Starlette deprecates
    passing a bare async-generator function as ``lifespan``.
    """
    app.mongodb_client = MongoClient(uri)
    app.database = app.mongodb_client['IT-Cluster1']  # interpretalk primary db
    try:
        app.mongodb_client.admin.command('ping')
        print("MongoDB Connection Established...")
    except Exception as e:
        print(e)
    try:
        yield
    finally:
        # Runs on normal shutdown and when an error propagates out of the app.
        print("Closing MongoDB Connection...")
        app.mongodb_client.close()
# FastAPI application; `lifespan` manages the MongoDB client.
# NOTE(review): FastAPI has no `logger` parameter; unknown keyword arguments are
# only stored in `app.extra`, so this does not configure FastAPI's logging.
app = FastAPI(logger=gunicorn_logger, lifespan=lifespan)

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins with allow_credentials=True accepts
# credentialed requests from any site; restrict allow_origins to the frontend
# origin if that is the intent.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# REST routes for user, call and transcript operations.
app.include_router(api_router)
# Runtime settings.
DEBUG = True
ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"
# Sampling rate passed to the SeamlessM4T processor together with each
# client's resampled audio.
TARGET_SAMPLING_RATE = 16000
# Buffered audio size (as reported by Client.get_length) that triggers
# processing of a client's audio.
MAX_BYTES_BUFFER = 960_000

# Startup banner, preceded by two blank lines.
print()
print()
print(f"{'=' * 18} Interpretalk is starting... {'=' * 18}")
# ---------------------------------------------------------------------------
# Static client build (remnant of the Seamless streaming demo)
# ---------------------------------------------------------------------------
# TODO PM - change this to the actual path
CLIENT_BUILD_PATH = "../streaming-react-app/dist/"

# Static-file map for the client build. Nothing else in this module passes it
# to socketio.ASGIApp or otherwise reads it.
static_files = dict()
static_files["/"] = CLIENT_BUILD_PATH
static_files["/assets/seamless-db6a2555.svg"] = {
    "filename": f"{CLIENT_BUILD_PATH}assets/seamless-db6a2555.svg",
    "content_type": "image/svg+xml",
}
# Inference device. Pinned to CPU for now (PM: local GPU lacks the VRAM).
# To use a GPU when one is available:
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# SeamlessM4T v2 processor and model, loaded once at import and shared by
# every connection's audio handling.
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")
model = model.to(device)

# Leftovers from the seamless_communication setup; not read anywhere else in
# this module.
bytes_data = bytearray()
model_name = "seamlessM4T_v2_large"
if model_name == "seamlessM4T_v2_large":
    vocoder_name = "vocoder_v2"
else:
    vocoder_name = "vocoder_36langs"

# Connection state.
clients = dict()  # Socket.IO sid -> Client
rooms = dict()    # call_id -> list of member sids (joins allow at most two)
def get_collection_users():
    """Return the ``user_records`` collection of the app database.

    ``app.database`` is assigned during the lifespan startup.
    """
    database = app.database
    return database["user_records"]
def get_collection_calls():
    """Return the ``call_records`` collection of the app database.

    ``app.database`` is assigned during the lifespan startup.
    """
    database = app.database
    return database["call_records"]
def test():
    """Return the static welcome payload (a new dict on every call)."""
    greeting = "Welcome to InterpreTalk!"
    return {"message": greeting}
async def send_translated_text(client_id, username, original_text, translated_text, room_id):
    """Broadcast one caption to every participant in a call room.

    Emits a ``translated_text`` Socket.IO event to ``room_id``. The payload
    carries the author's id and username, the original (ASR) text, the
    translated text and the server's local timestamp; all values are
    stringified.
    """
    payload = {
        "author_id": str(client_id),
        "author_username": str(username),
        "original_text": str(original_text),
        "translated_text": str(translated_text),
        "timestamp": str(datetime.now()),
    }
    gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
    await sio.emit("translated_text", payload, room=room_id)
    # The payload is text, not audio.
    gunicorn_logger.info("SUCCESSFULLY SENT TRANSLATED TEXT TO FRONTEND")
@sio.event
async def connect(sid, environ):
    """Register a new Socket.IO connection as a ``Client``.

    The stable per-user id comes from the ``client_id`` query parameter. It
    is used to look up the username in the users collection, and the
    resulting ``Client`` is stored in ``clients`` under the per-connection
    ``sid``.
    """
    print(f"📥 [event: connected] sid={sid}")
    # WSGI allows QUERY_STRING to be absent; treat that as "no parameters".
    query_params = dict(parse.parse_qsl(environ.get("QUERY_STRING", "")))
    client_id = query_params.get("client_id")
    gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")

    username = find_name_from_id(get_collection_users(), client_id)
    # sid changes per connection; client_id is the same for the same user.
    clients[sid] = Client(sid, client_id, username)
    print(clients[sid].username)
    gunicorn_logger.warning(f"Client connected: {sid}")
    gunicorn_logger.warning(clients)
@sio.event
async def disconnect(sid):
    """Clean up a dropped connection and post-process the user's call.

    Removes ``sid`` from ``clients`` and from its call room. It then runs key
    term extraction and summarisation for the call in the user's selected
    language. Errors from that post-processing are logged, not raised.
    """
    gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
    client = clients.pop(sid, None)
    if client is None:
        # connect() never registered this sid (or it was already cleaned up).
        return
    call_id = client.call_id
    user_id = client.client_id
    target_lang = client.target_language

    # Free the seat so the room is not reported full forever, and so
    # incoming_audio cannot pick this dead sid as the peer.
    members = rooms.get(call_id)
    if members is not None:
        if sid in members:
            members.remove(sid)
        if not members:
            del rooms[call_id]

    if call_id is None:
        # NOTE(review): presumably Client starts with call_id=None until the
        # user joins a call; nothing to post-process then — confirm.
        return

    try:
        # Key term extraction and summarisation of the call's captions.
        term_extraction(get_collection_calls(), call_id, user_id, target_lang)
        summarise(get_collection_calls(), call_id, user_id, target_lang)
    except Exception:
        gunicorn_logger.exception(f"📤 [event: term_extraction/summarisation request error] sid={sid}, call={call_id}")
@sio.event
async def target_language(sid, target_lang):
    """Store the language selected by ``sid``.

    The stored value is used as the source language of this user's speech
    and as the target language of captions translated for them.
    """
    gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
    client = clients.get(sid)
    if client is None:
        gunicorn_logger.warning(f"target_language from unknown sid={sid}; ignoring")
        return
    client.target_language = target_lang
@sio.event
async def call_user(sid, call_id):
    """Join ``sid`` to the ``call_id`` room as the caller and create the call record.

    A room holds at most two sids. The request is logged and ignored, and no
    call record is written, when:
    - the sid is unknown,
    - the sid is already in the room, or
    - the room is full.
    """
    client = clients.get(sid)
    if client is None:
        gunicorn_logger.warning(f"CALL {sid}: unknown client, ignoring")
        return
    gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid in members:
        gunicorn_logger.info(f"CALL {sid}: already in room {call_id}")
        return
    if len(members) >= 2:
        gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)
        return

    # Reserve the seat before the first await, so a concurrent join cannot
    # also pass the capacity check.
    members.append(sid)
    client.call_id = call_id
    # AsyncServer.enter_room is a coroutine in recent python-socketio releases
    # and a plain method in older ones. Without awaiting it, the sid never
    # joins the room and room broadcasts do not reach it.
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        await joined

    client_id = client.client_id
    gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
    # BO - record starts with caller and call_id; callee, duration, terms etc. come later.
    request_data = {
        "call_id": str(call_id),
        "caller_id": str(client_id),
        "creation_date": str(datetime.now()),
    }
    response = create_calls(get_collection_calls(), request_data)
    print(response)  # BO - print created db call record
@sio.event
async def audio_config(sid, sample_rate):
    """Store the sample rate reported by ``sid`` on its ``Client`` as ``original_sr``.

    NOTE(review): presumably used by Client when resampling this client's
    audio — confirm against Client.
    """
    client = clients.get(sid)
    if client is None:
        gunicorn_logger.warning(f"audio_config from unknown sid={sid}; ignoring")
        return
    client.original_sr = sample_rate
@sio.event
async def answer_call(sid, call_id):
    """Join ``sid`` to the ``call_id`` room as the callee and record it on the call.

    A room holds at most two sids. The request is logged and ignored when:
    - the sid is unknown,
    - the sid is already in the room, or
    - the room is full.
    This ensures a third party cannot overwrite the call's ``callee_id``.
    """
    client = clients.get(sid)
    if client is None:
        gunicorn_logger.warning(f"ANSWER {sid}: unknown client, ignoring")
        return
    gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid in members:
        gunicorn_logger.info(f"ANSWER {sid}: already in room {call_id}")
        return
    if len(members) >= 2:
        gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
        # await sio.emit("room_full", room=call_id, to=sid)
        return

    # Reserve the seat before the first await, so a concurrent join cannot
    # also pass the capacity check.
    members.append(sid)
    client.call_id = call_id
    # AsyncServer.enter_room is a coroutine in recent python-socketio releases
    # and a plain method in older ones. Without awaiting it, the sid never
    # joins the room and room broadcasts do not reach it.
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        await joined

    client_id = client.client_id
    gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Callee with ID: {client_id} for call: {call_id}")
    # BO -> set the callee_id field on the existing call record.
    request_data = {
        "callee_id": client_id
    }
    response = update_calls(get_collection_calls(), call_id, request_data)
    print(response)  # BO - print updated db call record
def _transcribe(audio, src_lang):
    """Run SeamlessM4T speech-to-text on ``audio`` and return the decoded text."""
    # Follows https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
    inputs = processor(audios=audio, src_lang=src_lang, return_tensors="pt", sampling_rate=TARGET_SAMPLING_RATE).to(device)
    # ASR: generate text in the speaker's own language.
    tokens = model.generate(**inputs, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
    return processor.decode(tokens, skip_special_tokens=True)


def _translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with SeamlessM4T text-to-text."""
    inputs = processor(text=text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt").to(device)
    tokens = model.generate(**inputs, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
    return processor.decode(tokens, skip_special_tokens=True)


@sio.event
async def incoming_audio(sid, data, call_id):
    """Buffer an audio chunk from ``sid``; once the buffer is full, caption it.

    When the buffered length reaches ``MAX_BYTES_BUFFER``:
    1. The audio is resampled and the buffer cleared.
    2. Voice activity detection runs on the result.
    3. If speech is present, it is transcribed in the speaker's language and
       translated into the other participant's language (when they differ).
    4. The caption is broadcast to the room and stored on the call record.

    Audio is dropped when no other participant is in the call. Any error is
    logged with its traceback rather than raised.

    NOTE(review): model inference runs synchronously inside this coroutine
    and blocks the event loop for all connections while it runs.
    """
    try:
        speaker = clients[sid]
        speaker.add_bytes(data)
        if speaker.get_length() < MAX_BYTES_BUFFER:
            return  # keep buffering
        gunicorn_logger.info('Buffer full, now outputting...')
        resampled_audio = speaker.resample_and_clear()
        if not speaker.vad_analyse(resampled_audio):
            return  # no speech in this window
        gunicorn_logger.info('Speech detected, now processing audio.....')

        # The speaker's selected language is the language they speak in.
        src_lang = speaker.target_language
        peer_sid = next((member for member in rooms.get(call_id, []) if member != sid), None)
        peer = clients.get(peer_sid)
        if peer is None:
            gunicorn_logger.info(f"No other participant in call {call_id} for sid={sid}; dropping audio")
            return
        tgt_lang = peer.target_language

        asr_text = _transcribe(resampled_audio, src_lang)
        print(f"ASR TEXT = {asr_text}")
        if src_lang != tgt_lang:
            translated_text = _translate(asr_text, src_lang, tgt_lang)
            print(f"TRANSLATED TEXT = {translated_text}")
        else:
            # PM - both users have the same language selected; no translation needed.
            translated_text = asr_text

        await send_translated_text(speaker.client_id, speaker.username, asr_text, translated_text, call_id)
        # BO -> store the caption on the call record for call_id.
        await send_captions(speaker.client_id, speaker.username, asr_text, translated_text, call_id)
    except Exception as e:
        # logger.exception records the traceback. The old e.with_traceback()
        # call needed an argument and raised TypeError, masking the real error.
        gunicorn_logger.exception(f"Error in incoming_audio: {e}")
async def send_captions(client_id, username, original_text, translated_text, call_id):
    """Store one caption for the call ``call_id`` via ``update_captions``.

    The caption has the same fields as the live ``translated_text`` payload:
    - author id and username,
    - original and translated text,
    - the server's local timestamp.
    All values are stringified. Returns whatever ``update_captions`` returns.
    """
    print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
    caption = dict(
        author_id=str(client_id),
        author_username=str(username),
        original_text=str(original_text),
        translated_text=str(translated_text),
        timestamp=str(datetime.now()),
    )
    return update_captions(get_collection_calls(), get_collection_users(), call_id, caption)
# Serve Socket.IO as a catch-all at the root. Routes added to `app` before this
# mount are matched first (Starlette checks routes in registration order).
app.mount("/", socketio_app)

if __name__ == "__main__":
    # Local run of this file: use verbose FastAPI logging, set *before* the
    # server starts (uvicorn.run blocks until shutdown). The app object is
    # served directly instead of the "main:app" import string, so the module
    # and the SeamlessM4T model it loads are not imported a second time.
    fastapi_logger.setLevel(logging.DEBUG)
    uvicorn.run(app, host='0.0.0.0', port=7860, log_level="info")
else:
    # Imported by a server process (e.g. in the Docker container): follow the
    # gunicorn logger's level.
    fastapi_logger.setLevel(gunicorn_logger.level)