lydiasolomon committed on
Commit
58a0d61
·
verified ·
1 Parent(s): 95d4ec9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -27
main.py CHANGED
@@ -8,8 +8,7 @@ from spitch import Spitch
8
  from langchain.prompts import PromptTemplate
9
  from langchain_huggingface import HuggingFaceEndpoint
10
  from langdetect import detect, DetectorFactory
11
- from huggingface_hub.utils import HfHubHTTPError
12
- from smebuilder_vector import retriever # Retriever for context injection
13
 
14
  # ----------------- CONFIG -----------------
15
  DetectorFactory.seed = 0
@@ -96,10 +95,10 @@ Output:
96
  """
97
 
98
  # ----------------- CHAINS -----------------
99
- chat_chain = PromptTemplate(input_variables=["question"], template=chat_template) | llm
100
- stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_template) | llm
101
- autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
102
- sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
103
 
104
  # ----------------- REQUEST MODELS -----------------
105
  class ChatRequest(BaseModel):
@@ -119,26 +118,16 @@ def check_auth(authorization: str | None):
119
  raise HTTPException(status_code=403, detail="Invalid token")
120
 
121
  # ----------------- HELPER FUNCTIONS -----------------
122
- def run_chain(chain, input_dict: dict):
123
  """
124
- Safely run a LangChain PromptTemplate | HuggingFaceEndpoint chain.
125
- Returns non-empty string, or detailed error info for debugging.
126
  """
127
  try:
128
- # Render template
129
- prompt_text = chain.prompt.format(**input_dict) if hasattr(chain, "prompt") else str(input_dict)
130
-
131
- # Generate using HuggingFaceEndpoint
132
- output = chain.llm.generate([{"role": "user", "content": prompt_text}])
133
-
134
- # Get text safely
135
- text = getattr(output.generations[0][0], "text", "") or ""
136
- text = text.strip()
137
-
138
- if not text:
139
  return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt_text}
140
-
141
- return text
142
  except Exception:
143
  return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc(), "prompt": prompt_text}
144
 
@@ -184,14 +173,16 @@ def root():
184
  @app.post("/chat")
185
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
186
  check_auth(authorization)
187
- result = run_chain(chat_chain, {"question": req.question})
 
188
  return result if isinstance(result, dict) else {"reply": result}
189
 
190
  @app.post("/stt")
191
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
192
  check_auth(authorization)
193
  transcription, detected_lang, translation = await process_audio(file, lang_hint)
194
- result = run_chain(stt_chain, {"speech": translation})
 
195
  return {
196
  "transcription": transcription,
197
  "detected_language": detected_lang,
@@ -202,7 +193,8 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
202
  @app.post("/autodoc")
203
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
204
  check_auth(authorization)
205
- result = run_chain(autodoc_chain, {"code": req.code})
 
206
  return result if isinstance(result, dict) else {"documentation": result}
207
 
208
  @app.post("/sme/generate")
@@ -212,7 +204,8 @@ async def sme_generate(payload: dict = Body(...), authorization: str | None = He
212
  user_prompt = payload.get("user_prompt", "")
213
  context_docs = retriever.get_relevant_documents(user_prompt)
214
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
215
- result = run_chain(sme_chain, {"user_prompt": user_prompt, "context": context})
 
216
  return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
217
  except Exception:
218
  return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
@@ -224,7 +217,8 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
224
  try:
225
  context_docs = retriever.get_relevant_documents(translation)
226
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
227
- result = run_chain(sme_chain, {"user_prompt": translation, "context": context})
 
228
  return {
229
  "success": True,
230
  "transcription": transcription,
 
8
  from langchain.prompts import PromptTemplate
9
  from langchain_huggingface import HuggingFaceEndpoint
10
  from langdetect import detect, DetectorFactory
11
+ from smebuilder_vector import retriever
 
12
 
13
  # ----------------- CONFIG -----------------
14
  DetectorFactory.seed = 0
 
95
  """
96
 
97
  # ----------------- CHAINS -----------------
98
+ chat_chain = PromptTemplate(input_variables=["question"], template=chat_template)
99
+ stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_template)
100
+ autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template)
101
+ sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template)
102
 
103
  # ----------------- REQUEST MODELS -----------------
104
  class ChatRequest(BaseModel):
 
118
  raise HTTPException(status_code=403, detail="Invalid token")
119
 
120
  # ----------------- HELPER FUNCTIONS -----------------
121
+ def run_llm(prompt_text: str):
122
  """
123
+ Directly run HuggingFaceEndpoint with string input.
124
+ Returns text or error dict.
125
  """
126
  try:
127
+ output = llm(prompt_text)
128
+ if not output.strip():
 
 
 
 
 
 
 
 
 
129
  return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt_text}
130
+ return output.strip()
 
131
  except Exception:
132
  return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc(), "prompt": prompt_text}
133
 
 
173
  @app.post("/chat")
174
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
175
  check_auth(authorization)
176
+ prompt_text = chat_chain.format(question=req.question)
177
+ result = run_llm(prompt_text)
178
  return result if isinstance(result, dict) else {"reply": result}
179
 
180
  @app.post("/stt")
181
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
182
  check_auth(authorization)
183
  transcription, detected_lang, translation = await process_audio(file, lang_hint)
184
+ prompt_text = stt_chain.format(speech=translation)
185
+ result = run_llm(prompt_text)
186
  return {
187
  "transcription": transcription,
188
  "detected_language": detected_lang,
 
193
  @app.post("/autodoc")
194
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
195
  check_auth(authorization)
196
+ prompt_text = autodoc_chain.format(code=req.code)
197
+ result = run_llm(prompt_text)
198
  return result if isinstance(result, dict) else {"documentation": result}
199
 
200
  @app.post("/sme/generate")
 
204
  user_prompt = payload.get("user_prompt", "")
205
  context_docs = retriever.get_relevant_documents(user_prompt)
206
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
207
+ prompt_text = sme_chain.format(user_prompt=user_prompt, context=context)
208
+ result = run_llm(prompt_text)
209
  return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
210
  except Exception:
211
  return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
 
217
  try:
218
  context_docs = retriever.get_relevant_documents(translation)
219
  context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
220
+ prompt_text = sme_chain.format(user_prompt=translation, context=context)
221
+ result = run_llm(prompt_text)
222
  return {
223
  "success": True,
224
  "transcription": transcription,