Spaces:
Sleeping
Sleeping
File size: 6,212 Bytes
fbe7105 3f5192e 9da1b27 fbe7105 9da1b27 fbe7105 913a039 fbe7105 913a039 fbe7105 9da1b27 fbe7105 3f5192e fbe7105 3f5192e fbe7105 3f5192e fbe7105 3f5192e fbe7105 9da1b27 8d8ce1a 43344fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import os
import tempfile
from functools import lru_cache
from io import BytesIO

import sounddevice as sd
import speech_recognition as sr
from fastapi import FastAPI, HTTPException
from fastapi import UploadFile, File
from fastapi.responses import StreamingResponse
from gtts import gTTS
from pydantic import BaseModel
from scipy.io.wavfile import write
from transformers import MarianMTModel, MarianTokenizer
# import pyttsx3
# import uvicorn
# Initialize the FastAPI application object; the @app.post / @app.get
# decorators in this module register their handlers on this instance.
app = FastAPI()
# CORS configuration
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"], # Allow all origins for development; adjust for production
# allow_credentials=True,
# allow_methods=["*"], # Allow all HTTP methods
# allow_headers=["*"], # Allow all headers
# )
# Initialize TTS engine for speaking the translated text
# engine = pyttsx3.init()
# Supported languages dictionary
# Language codes accepted by the /translate endpoint, mapped to a
# human-readable English name for each code.
supported_languages = dict(
    [
        ("en", "English"),
        ("fr", "French"),
        ("es", "Spanish"),
        ("de", "German"),
        ("it", "Italian"),
        ("ru", "Russian"),
        ("zh", "Chinese"),
        ("ar", "Arabic"),
        ("hi", "Hindi"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
        ("pt", "Portuguese"),
        ("nl", "Dutch"),
        ("sv", "Swedish"),
        ("pl", "Polish"),
        ("tr", "Turkish"),
        ("vi", "Vietnamese"),
        ("th", "Thai"),
        ("he", "Hebrew"),
        ("id", "Indonesian"),
    ]
)
# Model for input data validation
class TranslationRequest(BaseModel):
    """Request body for the ``/translate`` endpoint.

    All three fields are required strings; the language codes are checked
    against ``supported_languages`` by the endpoint, not by this model.
    """

    # Code of the language the text is written in (e.g. "en").
    src_lang: str
    # Code of the language to translate into (e.g. "fr").
    tgt_lang: str
    # The text to translate.
    text: str
class SpeakRequest(BaseModel):
    """Request body carrying a single piece of text to be spoken.

    Only the commented-out POST ``/speak`` handler in this module refers
    to this model; the active GET ``/speak`` route takes a query parameter.
    """

    # The text to convert to speech.
    text: str
# Bounded cache: each entry keeps a full model in memory, so only a handful
# of recently used language pairs are retained.
@lru_cache(maxsize=8)
def load_model(src_lang, tgt_lang):
    """Load the MarianMT model and tokenizer for a language pair.

    Results are memoized per ``(src_lang, tgt_lang)`` so that repeated
    requests for the same pair reuse the already-loaded objects instead of
    re-instantiating them from the pretrained files on every call. Failed
    loads are not cached (``lru_cache`` only stores returned values), so a
    later call retries.

    Args:
        src_lang: Source language code used in the model name.
        tgt_lang: Target language code used in the model name.

    Returns:
        A ``(model, tokenizer)`` tuple for
        ``Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}``.

    Raises:
        Exception: Whatever ``from_pretrained`` raises (for example when no
            model exists for the pair); the error is logged and re-raised.
    """
    model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
    except Exception as e:
        print(f"Model loading error: {e}")  # Log error
        raise
    return model, tokenizer
def translate_text(text, model, tokenizer):
    """Translate ``text`` using an already-loaded model/tokenizer pair.

    Args:
        text: The source-language string to translate.
        model: Object exposing ``generate(**inputs)`` (a ``MarianMTModel``
            when obtained from ``load_model``).
        tokenizer: Callable tokenizer that also exposes ``decode``.

    Returns:
        The decoded first generated sequence with special tokens removed,
        or ``""`` when ``text`` is empty or only whitespace.
    """
    # Nothing to translate: skip the model entirely rather than letting it
    # generate output for an empty prompt.
    if not text or not text.strip():
        return ""

    # truncation=True cuts over-long input to the tokenizer's maximum length
    # instead of handing the model more positions than it supports.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated = model.generate(**inputs)
    return tokenizer.decode(translated[0], skip_special_tokens=True)
@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate text from a source language to a target language.

    Args:
        request: Body with ``src_lang``, ``tgt_lang`` and ``text``.

    Returns:
        ``{"translated_text": <str>}``.

    Raises:
        HTTPException: 400 when either language code is not a key of
            ``supported_languages``; 500 when loading the model or
            translating fails (the detail carries the error message).
    """
    # Imported locally so this handler is self-contained; fastapi is already
    # a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    if request.src_lang not in supported_languages or request.tgt_lang not in supported_languages:
        raise HTTPException(status_code=400, detail="Unsupported language.")
    try:
        print(f"Translating text: {request.text} from {request.src_lang} to {request.tgt_lang}")  # Debug info
        # Model loading and generation are synchronous, long-running calls;
        # run them in the threadpool so the event loop stays free to serve
        # other requests in the meantime.
        model, tokenizer = await run_in_threadpool(
            load_model, request.src_lang, request.tgt_lang
        )
        translated_text = await run_in_threadpool(
            translate_text, request.text, model, tokenizer
        )
        print(f"Translated text: {translated_text}")  # Debug info
        return {"translated_text": translated_text}
    except Exception as e:
        print(f"Error: {e}")  # Log error
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/recognize")
async def recognize(language: str = 'en', file: UploadFile = File(...)):
    """Recognize speech from an uploaded WAV file.

    The upload is written to a temporary ``.wav`` file, read back through
    ``speech_recognition`` and passed to ``recognize_google``.

    Args:
        language: Language code forwarded to ``recognize_google``.
        file: The uploaded audio file.

    Returns:
        ``{"recognized_text": <str>}``.

    Raises:
        HTTPException: 400 when the audio cannot be understood
            (``sr.UnknownValueError``); 500 on ``sr.RequestError`` or any
            other failure.
    """
    temp_audio_path = None
    try:
        payload = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            # Record the path before writing so cleanup still happens if
            # the write itself fails.
            temp_audio_path = temp_audio.name
            temp_audio.write(payload)

        recognizer = sr.Recognizer()
        with sr.AudioFile(temp_audio_path) as source:
            audio = recognizer.record(source)
        recognized_text = recognizer.recognize_google(audio, language=language)
        return {"recognized_text": recognized_text}
    except sr.UnknownValueError as e:
        raise HTTPException(status_code=400, detail="Could not understand the audio") from e
    except sr.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Recognition error: {e}") from e
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the temporary file on every path, not only on success, so
        # failed recognitions do not leave files behind.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError as cleanup_error:
                # Best-effort cleanup: never mask the real outcome.
                print(f"Error: {cleanup_error}")
# @app.post("/recognize")
# async def recognize(language: str = 'en'):
# """Capture speech and recognize it using sounddevice."""
# duration = 10 # Duration of recording in seconds
# samplerate = 16000 # Sample rate for recording
# try:
# print("Recording audio...") # Debug info
# # Record audio using sounddevice
# recording = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='int16')
# sd.wait() # Wait until recording is finished
# # Save audio to a temporary file
# with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
# write(temp_audio_file.name, samplerate, recording)
# temp_audio_path = temp_audio_file.name
# # Recognize speech using speech_recognition
# recognizer = sr.Recognizer()
# with sr.AudioFile(temp_audio_path) as source:
# audio = recognizer.record(source)
# recognized_text = recognizer.recognize_google(audio, language=language)
# # Clean up temporary audio file
# os.remove(temp_audio_path)
# return {"recognized_text": recognized_text}
# except sr.UnknownValueError:
# raise HTTPException(status_code=400, detail="Could not understand the audio")
# except sr.RequestError as e:
# raise HTTPException(status_code=500, detail=f"Recognition error: {e}")
# except Exception as e:
# print(f"Error: {e}") # Log error
# raise HTTPException(status_code=500, detail=str(e))
@app.get("/speak")
def speak(text: str, lang: str = "en"):
    """Convert text to speech and stream it back as MP3 audio.

    Args:
        text: The text to speak (query parameter).
        lang: Language code passed to gTTS; defaults to ``"en"``, so callers
            that only send ``text`` keep getting English speech.

    Returns:
        A ``StreamingResponse`` with media type ``audio/mpeg``.

    Raises:
        HTTPException: 400 when ``text`` is empty or whitespace, or when
            gTTS rejects the arguments with ``ValueError``; 500 when speech
            synthesis fails for any other reason.
    """
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="No text provided.")

    mp3_fp = BytesIO()
    try:
        tts = gTTS(text, lang=lang)
        tts.write_to_fp(mp3_fp)
    except ValueError as e:
        # NOTE(review): gTTS is expected to raise ValueError for an
        # unsupported language code - confirm against the installed version.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        print(f"Error: {e}")  # Log error
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Rewind so the response streams the audio from the first byte.
    mp3_fp.seek(0)
    return StreamingResponse(mp3_fp, media_type="audio/mpeg")
# @app.post("/speak")
# async def speak(request: SpeakRequest):
# """Speak out a given text."""
# if not request.text:
# raise HTTPException(status_code=400, detail="No text provided.")
# try:
# engine.say(request.text)
# engine.runAndWait()
# return {"message": "Text spoken successfully"}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
# if __name__ == "__main__":
# uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) |