1MR commited on
Commit
fbe7105
·
verified ·
1 Parent(s): 33c30f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging
import os
import tempfile
from functools import lru_cache

import pyttsx3
import sounddevice as sd
import speech_recognition as sr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from scipy.io.wavfile import write
from transformers import MarianMTModel, MarianTokenizer
11
+
12
# Initialize FastAPI app
app = FastAPI()

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for development; adjust for production
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

# Initialize TTS engine for speaking the translated text.
# NOTE(review): a single module-level pyttsx3 engine is shared by all requests;
# presumably /speak calls are expected to be serialized — confirm under
# concurrent load before production use.
engine = pyttsx3.init()

# Supported languages dictionary: ISO 639-1 code -> human-readable name.
# Used only to validate /translate requests; whether a given pair actually
# works depends on a Helsinki-NLP/opus-mt-{src}-{tgt} checkpoint existing.
supported_languages = {
    "en": "English", "fr": "French", "es": "Spanish", "de": "German",
    "it": "Italian", "ru": "Russian", "zh": "Chinese", "ar": "Arabic",
    "hi": "Hindi", "ja": "Japanese", "ko": "Korean", "pt": "Portuguese",
    "nl": "Dutch", "sv": "Swedish", "pl": "Polish", "tr": "Turkish",
    "vi": "Vietnamese", "th": "Thai", "he": "Hebrew", "id": "Indonesian"
}
35
+
36
# Model for input data validation
class TranslationRequest(BaseModel):
    """Request body for POST /translate."""
    src_lang: str  # source language ISO 639-1 code; must be a key of supported_languages
    tgt_lang: str  # target language ISO 639-1 code; must be a key of supported_languages
    text: str  # text to translate
41
+
42
class SpeakRequest(BaseModel):
    """Request body for POST /speak."""
    text: str  # text to synthesize; empty string is rejected with HTTP 400
44
+
45
@lru_cache(maxsize=8)
def load_model(src_lang, tgt_lang):
    """Load (and cache) the MarianMT translation model for a language pair.

    The original implementation re-downloaded/re-deserialized the checkpoint
    on every request; since the arguments are plain strings, lru_cache
    memoizes the (model, tokenizer) pair per language direction. Failed
    loads are NOT cached (lru_cache does not cache exceptions).

    Args:
        src_lang: source language ISO 639-1 code.
        tgt_lang: target language ISO 639-1 code.

    Returns:
        Tuple of (MarianMTModel, MarianTokenizer).

    Raises:
        Whatever ``from_pretrained`` raises (e.g. OSError) when the
        ``Helsinki-NLP/opus-mt-{src}-{tgt}`` checkpoint does not exist.
    """
    model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
    except Exception:
        # Log with traceback, then propagate so the endpoint can map it to 500.
        logging.getLogger(__name__).exception("Model loading error for %s", model_name)
        raise
    return model, tokenizer
55
+
56
def translate_text(text, model, tokenizer):
    """Translate *text* with the given MarianMT model/tokenizer pair.

    Encodes the input as PyTorch tensors, generates a translation, and
    decodes the first (only) output sequence with special tokens stripped.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    output_ids = model.generate(**encoded)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
62
+
63
@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate text from source to target language.

    Returns:
        {"translated_text": str} on success.

    Raises:
        HTTPException 400 if either language code is unsupported.
        HTTPException 500 if model loading or generation fails.
    """
    if request.src_lang not in supported_languages or request.tgt_lang not in supported_languages:
        raise HTTPException(status_code=400, detail="Unsupported language.")

    logger = logging.getLogger(__name__)
    try:
        # Lazy %-style args keep formatting off the hot path when DEBUG is off.
        logger.debug("Translating %r from %s to %s",
                     request.text, request.src_lang, request.tgt_lang)
        model, tokenizer = load_model(request.src_lang, request.tgt_lang)
        translated_text = translate_text(request.text, model, tokenizer)
        logger.debug("Translated text: %r", translated_text)
        return {"translated_text": translated_text}
    except Exception as e:
        # logger.exception records the traceback; print() did not.
        logger.exception("Translation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
78
+
79
@app.post("/recognize")
async def recognize(language: str = 'en'):
    """Record 10 seconds of microphone audio and recognize it via Google.

    Fix over the original: the temporary WAV file was only deleted on the
    success path, so every UnknownValueError/RequestError leaked a file in
    the temp directory. Cleanup now happens in ``finally``.

    Args:
        language: BCP-47 / ISO language hint passed to recognize_google.

    Returns:
        {"recognized_text": str} on success.

    Raises:
        HTTPException 400 if the audio could not be understood.
        HTTPException 500 on recognition-service or recording errors.
    """
    duration = 10  # Duration of recording in seconds
    samplerate = 16000  # Sample rate for recording

    temp_audio_path = None
    try:
        # Record audio using sounddevice (blocking until done).
        recording = sd.rec(int(duration * samplerate), samplerate=samplerate,
                           channels=1, dtype='int16')
        sd.wait()  # Wait until recording is finished

        # Save audio to a temporary file; delete=False so sr can reopen it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
            write(temp_audio_file.name, samplerate, recording)
            temp_audio_path = temp_audio_file.name

        # Recognize speech using speech_recognition.
        recognizer = sr.Recognizer()
        with sr.AudioFile(temp_audio_path) as source:
            audio = recognizer.record(source)

        recognized_text = recognizer.recognize_google(audio, language=language)
        return {"recognized_text": recognized_text}
    except sr.UnknownValueError:
        raise HTTPException(status_code=400, detail="Could not understand the audio")
    except sr.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Recognition error: {e}")
    except Exception as e:
        logging.getLogger(__name__).exception("Recognition failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, even when recognition raised.
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
114
+
115
@app.post("/speak")
async def speak(request: SpeakRequest):
    """Speak the given text aloud through the shared pyttsx3 engine.

    NOTE(review): runAndWait() is a blocking call inside an async handler —
    presumably acceptable for this single-user app; confirm before scaling.

    Raises:
        HTTPException 400 if no text was provided.
        HTTPException 500 if the TTS engine fails.
    """
    text = request.text
    if not text:
        raise HTTPException(status_code=400, detail="No text provided.")
    try:
        engine.say(text)
        engine.runAndWait()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Text spoken successfully"}