Spaces:
Sleeping
Sleeping
Commit
·
351bb59
1
Parent(s):
22ecd09
Updated
Browse files
main.py
CHANGED
|
@@ -7,8 +7,8 @@ from spitch import Spitch
|
|
| 7 |
from langchain.prompts import PromptTemplate
|
| 8 |
from langchain_huggingface import HuggingFaceEndpoint
|
| 9 |
from langdetect import detect, DetectorFactory
|
| 10 |
-
from huggingface_hub.utils import HfHubHTTPError
|
| 11 |
-
from smebuilder_vector import retriever #
|
| 12 |
|
| 13 |
# ----------------- CONFIG -----------------
|
| 14 |
DetectorFactory.seed = 0
|
|
@@ -25,14 +25,14 @@ if not SPITCH_API_KEY:
|
|
| 25 |
os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
|
| 26 |
spitch_client = Spitch()
|
| 27 |
|
| 28 |
-
# HuggingFace LLM
|
| 29 |
llm = HuggingFaceEndpoint(
|
| 30 |
repo_id=HF_MODEL,
|
| 31 |
temperature=0.7,
|
| 32 |
top_p=0.9,
|
| 33 |
do_sample=True,
|
| 34 |
repetition_penalty=1.1,
|
| 35 |
-
max_new_tokens=2048
|
| 36 |
)
|
| 37 |
|
| 38 |
# FastAPI app
|
|
@@ -143,8 +143,7 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
|
|
| 143 |
|
| 144 |
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 145 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
| 146 |
-
|
| 147 |
-
tf.write(content)
|
| 148 |
tmp_path = tf.name
|
| 149 |
|
| 150 |
try:
|
|
@@ -156,7 +155,6 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
|
|
| 156 |
resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
|
| 157 |
|
| 158 |
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
|
| 159 |
-
|
| 160 |
try:
|
| 161 |
detected_lang = detect(transcription) if transcription.strip() else "en"
|
| 162 |
except Exception:
|
|
@@ -171,7 +169,6 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
|
|
| 171 |
translation = transcription
|
| 172 |
|
| 173 |
reply = stt_chain.invoke({"speech": translation})
|
| 174 |
-
|
| 175 |
return {
|
| 176 |
"transcription": transcription,
|
| 177 |
"detected_language": detected_lang,
|
|
@@ -189,7 +186,6 @@ def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
|
|
| 189 |
async def sme_generate(payload: dict = Body(...)):
|
| 190 |
try:
|
| 191 |
user_prompt = payload.get("user_prompt", "")
|
| 192 |
-
# retrieve context
|
| 193 |
context_docs = retriever.get_relevant_documents(user_prompt)
|
| 194 |
context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
|
| 195 |
response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
|
|
@@ -205,8 +201,7 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
|
|
| 205 |
|
| 206 |
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 207 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
| 208 |
-
|
| 209 |
-
tf.write(content)
|
| 210 |
tmp_path = tf.name
|
| 211 |
|
| 212 |
try:
|
|
@@ -218,7 +213,6 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
|
|
| 218 |
resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
|
| 219 |
|
| 220 |
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
|
| 221 |
-
|
| 222 |
try:
|
| 223 |
detected_lang = detect(transcription) if transcription.strip() else "en"
|
| 224 |
except Exception:
|
|
@@ -233,10 +227,8 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
|
|
| 233 |
translation = transcription
|
| 234 |
|
| 235 |
try:
|
| 236 |
-
# vector retrieval here too
|
| 237 |
context_docs = retriever.get_relevant_documents(translation)
|
| 238 |
context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
|
| 239 |
-
|
| 240 |
sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
|
| 241 |
return {
|
| 242 |
"success": True,
|
|
@@ -250,7 +242,7 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
|
|
| 250 |
return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
|
| 251 |
raise e
|
| 252 |
|
| 253 |
-
#
|
| 254 |
if __name__ == "__main__":
|
| 255 |
import uvicorn
|
| 256 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
|
|
|
|
| 7 |
from langchain.prompts import PromptTemplate
|
| 8 |
from langchain_huggingface import HuggingFaceEndpoint
|
| 9 |
from langdetect import detect, DetectorFactory
|
| 10 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 11 |
+
from smebuilder_vector import retriever # your retriever
|
| 12 |
|
| 13 |
# ----------------- CONFIG -----------------
|
| 14 |
DetectorFactory.seed = 0
|
|
|
|
| 25 |
os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
|
| 26 |
spitch_client = Spitch()
|
| 27 |
|
| 28 |
+
# HuggingFace LLM
|
| 29 |
llm = HuggingFaceEndpoint(
|
| 30 |
repo_id=HF_MODEL,
|
| 31 |
temperature=0.7,
|
| 32 |
top_p=0.9,
|
| 33 |
do_sample=True,
|
| 34 |
repetition_penalty=1.1,
|
| 35 |
+
max_new_tokens=2048
|
| 36 |
)
|
| 37 |
|
| 38 |
# FastAPI app
|
|
|
|
| 143 |
|
| 144 |
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 145 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
| 146 |
+
tf.write(await file.read())
|
|
|
|
| 147 |
tmp_path = tf.name
|
| 148 |
|
| 149 |
try:
|
|
|
|
| 155 |
resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
|
| 156 |
|
| 157 |
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
|
|
|
|
| 158 |
try:
|
| 159 |
detected_lang = detect(transcription) if transcription.strip() else "en"
|
| 160 |
except Exception:
|
|
|
|
| 169 |
translation = transcription
|
| 170 |
|
| 171 |
reply = stt_chain.invoke({"speech": translation})
|
|
|
|
| 172 |
return {
|
| 173 |
"transcription": transcription,
|
| 174 |
"detected_language": detected_lang,
|
|
|
|
| 186 |
async def sme_generate(payload: dict = Body(...)):
|
| 187 |
try:
|
| 188 |
user_prompt = payload.get("user_prompt", "")
|
|
|
|
| 189 |
context_docs = retriever.get_relevant_documents(user_prompt)
|
| 190 |
context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
|
| 191 |
response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
|
|
|
|
| 201 |
|
| 202 |
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 203 |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
|
| 204 |
+
tf.write(await file.read())
|
|
|
|
| 205 |
tmp_path = tf.name
|
| 206 |
|
| 207 |
try:
|
|
|
|
| 213 |
resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
|
| 214 |
|
| 215 |
transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
|
|
|
|
| 216 |
try:
|
| 217 |
detected_lang = detect(transcription) if transcription.strip() else "en"
|
| 218 |
except Exception:
|
|
|
|
| 227 |
translation = transcription
|
| 228 |
|
| 229 |
try:
|
|
|
|
| 230 |
context_docs = retriever.get_relevant_documents(translation)
|
| 231 |
context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
|
|
|
|
| 232 |
sme_response = sme_chain.invoke({"user_prompt": translation, "context": context})
|
| 233 |
return {
|
| 234 |
"success": True,
|
|
|
|
| 242 |
return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
|
| 243 |
raise e
|
| 244 |
|
| 245 |
+
# ----------------- MAIN -----------------
|
| 246 |
if __name__ == "__main__":
|
| 247 |
import uvicorn
|
| 248 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
|