lydiasolomon commited on
Commit
2bab06e
·
verified ·
1 Parent(s): f14c1fe

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +39 -86
main.py CHANGED
@@ -1,8 +1,6 @@
1
  import os
2
  import tempfile
3
- import logging
4
- from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body, Request
5
- from fastapi.responses import JSONResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
  from spitch import Spitch
@@ -12,17 +10,13 @@ from langdetect import detect, DetectorFactory
12
  from huggingface_hub.utils import HfHubHTTPError
13
  from smebuilder_vector import retriever # Retriever for context injection
14
 
15
- # ----------------- LOGGING -----------------
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger("DevAssist")
18
-
19
  # ----------------- CONFIG -----------------
20
  DetectorFactory.seed = 0
21
 
22
  SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
23
  HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
24
  FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
25
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY") # default if not set
26
 
27
  if not SPITCH_API_KEY:
28
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
@@ -41,31 +35,7 @@ llm = HuggingFaceEndpoint(
41
  max_new_tokens=2048
42
  )
43
 
44
- # ----------------- HELPERS -----------------
45
- def run_llm_model(chain, payload: dict):
46
- """
47
- Safely run HuggingFace model through LangChain chain.
48
- Handles string, dict, and list responses without crashing.
49
- """
50
- try:
51
- result = chain.invoke(payload)
52
- logger.info(f"HF raw response: {result}")
53
-
54
- if isinstance(result, str):
55
- return result.strip()
56
-
57
- if isinstance(result, dict) and "generated_text" in result:
58
- return result["generated_text"].strip()
59
-
60
- if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
61
- return result[0]["generated_text"].strip()
62
-
63
- return str(result).strip()
64
- except Exception as e:
65
- logger.error(f"LLM execution failed: {e}")
66
- return f"⚠️ LLM error: {str(e)}"
67
-
68
- # ----------------- FASTAPI -----------------
69
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
70
 
71
  # CORS
@@ -77,13 +47,7 @@ app.add_middleware(
77
  allow_headers=["Authorization", "Content-Type"],
78
  )
79
 
80
- # Global exception handler
81
- @app.exception_handler(Exception)
82
- async def global_exception_handler(request: Request, exc: Exception):
83
- logger.error(f"Unhandled error: {exc}")
84
- return JSONResponse(status_code=500, content={"error": str(exc)})
85
-
86
- # ----------------- PROMPTS -----------------
87
  chat_template = """You are DevAssist, an AI coding assistant.
88
 
89
  Guidelines:
@@ -114,6 +78,7 @@ Code: {code}
114
  Documentation:
115
  """
116
 
 
117
  sme_template = """
118
  You are a senior full-stack engineer specializing in modern front-end development.
119
  Your job is to generate **production-ready code** for websites and apps.
@@ -125,7 +90,7 @@ Guidelines:
125
  - JavaScript must add interactivity (animations, toggles, button actions)
126
  - Include hero, feature grid, testimonials, and footer
127
  - Use realistic content (no lorem ipsum, no placeholders)
128
- - Return ONLY valid JSON: { "files": { "index.html": "...", "styles.css": "...", "script.js": "..." } }
129
 
130
  Prompt: {user_prompt}
131
  Context: {context}
@@ -139,7 +104,7 @@ stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_templat
139
  autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
140
  sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
141
 
142
- # ----------------- MODELS -----------------
143
  class ChatRequest(BaseModel):
144
  question: str
145
 
@@ -148,7 +113,8 @@ class AutoDocRequest(BaseModel):
148
 
149
  # ----------------- AUTH -----------------
150
  def check_auth(authorization: str | None):
151
- if not PROJECT_API_KEY: # if no key set, skip
 
152
  return
153
  if not authorization or not authorization.startswith("Bearer "):
154
  raise HTTPException(status_code=401, detail="Missing bearer token")
@@ -164,18 +130,28 @@ def root():
164
  @app.post("/chat")
165
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
166
  check_auth(authorization)
167
- return {"reply": run_llm_model(chat_chain, {"question": req.question})}
 
 
 
 
 
 
168
 
169
  @app.post("/stt")
170
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
171
  check_auth(authorization)
 
172
  suffix = os.path.splitext(file.filename)[1] or ".wav"
173
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
174
  tf.write(await file.read())
175
  tmp_path = tf.name
176
 
177
  try:
178
- resp = spitch_client.speech.transcribe(language=lang_hint or "en", content=open(tmp_path, "rb").read())
 
 
 
179
  except Exception:
180
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
181
 
@@ -194,56 +170,33 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
194
  except Exception:
195
  translation = transcription
196
 
197
- reply = run_llm_model(stt_chain, {"speech": translation})
198
- return {"transcription": transcription, "detected_language": detected_lang, "translation": translation, "reply": reply}
 
 
 
 
 
199
 
200
  @app.post("/autodoc")
201
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
202
  check_auth(authorization)
203
- docs = run_llm_model(autodoc_chain, {"code": req.code})
204
- return {"documentation": docs}
205
 
206
  @app.post("/sme/generate")
207
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
208
  check_auth(authorization)
209
- user_prompt = payload.get("user_prompt", "")
210
- context_docs = retriever.get_relevant_documents(user_prompt)
211
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
212
- response = run_llm_model(sme_chain, {"user_prompt": user_prompt, "context": context})
213
- return {"success": True, "data": response}
214
-
215
- @app.post("/sme/speech-generate")
216
- async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
217
- check_auth(authorization)
218
- suffix = os.path.splitext(file.filename)[1] or ".wav"
219
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
220
- tf.write(await file.read())
221
- tmp_path = tf.name
222
-
223
- try:
224
- resp = spitch_client.speech.transcribe(language=lang_hint or "en", content=open(tmp_path, "rb").read())
225
- except Exception:
226
- resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
227
-
228
- transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
229
- detected_lang = "en"
230
  try:
231
- detected_lang = detect(transcription) if transcription.strip() else "en"
232
- except Exception:
233
- pass
234
-
235
- translation = transcription
236
- if detected_lang != "en":
237
- try:
238
- translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
239
- translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "")
240
- except Exception:
241
- translation = transcription
242
-
243
- context_docs = retriever.get_relevant_documents(translation)
244
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
245
- sme_response = run_llm_model(sme_chain, {"user_prompt": translation, "context": context})
246
- return {"success": True, "transcription": transcription, "detected_language": detected_lang, "translation": translation, "sme_site": sme_response}
247
 
248
  # ----------------- MAIN -----------------
249
  if __name__ == "__main__":
 
1
  import os
2
  import tempfile
3
+ from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
 
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
  from spitch import Spitch
 
10
  from huggingface_hub.utils import HfHubHTTPError
11
  from smebuilder_vector import retriever # Retriever for context injection
12
 
 
 
 
 
13
  # ----------------- CONFIG -----------------
14
  DetectorFactory.seed = 0
15
 
16
  SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
17
  HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
18
  FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
19
+ PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
20
 
21
  if not SPITCH_API_KEY:
22
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
 
35
  max_new_tokens=2048
36
  )
37
 
38
+ # FastAPI app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
40
 
41
  # CORS
 
47
  allow_headers=["Authorization", "Content-Type"],
48
  )
49
 
50
+ # ----------------- PROMPT TEMPLATES -----------------
 
 
 
 
 
 
51
  chat_template = """You are DevAssist, an AI coding assistant.
52
 
53
  Guidelines:
 
78
  Documentation:
79
  """
80
 
81
+ # 🔥 Fixed SME template with escaped curly braces
82
  sme_template = """
83
  You are a senior full-stack engineer specializing in modern front-end development.
84
  Your job is to generate **production-ready code** for websites and apps.
 
90
  - JavaScript must add interactivity (animations, toggles, button actions)
91
  - Include hero, feature grid, testimonials, and footer
92
  - Use realistic content (no lorem ipsum, no placeholders)
93
+ - Return ONLY valid JSON: {{ "files": {{ "index.html": "...", "styles.css": "...", "script.js": "..." }} }}
94
 
95
  Prompt: {user_prompt}
96
  Context: {context}
 
104
  autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
105
  sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
106
 
107
+ # ----------------- REQUEST MODELS -----------------
108
  class ChatRequest(BaseModel):
109
  question: str
110
 
 
113
 
114
  # ----------------- AUTH -----------------
115
  def check_auth(authorization: str | None):
116
+ """Validate Bearer token against PROJECT_API_KEY"""
117
+ if not PROJECT_API_KEY:
118
  return
119
  if not authorization or not authorization.startswith("Bearer "):
120
  raise HTTPException(status_code=401, detail="Missing bearer token")
 
130
  @app.post("/chat")
131
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
132
  check_auth(authorization)
133
+ try:
134
+ answer = chat_chain.invoke({"question": req.question})
135
+ return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
136
+ except HfHubHTTPError as e:
137
+ if "exceeded" in str(e).lower() or "quota" in str(e).lower():
138
+ return {"reply": "⚠️ Daily token limit reached. Try again in 24 hours."}
139
+ raise e
140
 
141
  @app.post("/stt")
142
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
143
  check_auth(authorization)
144
+
145
  suffix = os.path.splitext(file.filename)[1] or ".wav"
146
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
147
  tf.write(await file.read())
148
  tmp_path = tf.name
149
 
150
  try:
151
+ if lang_hint:
152
+ resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
153
+ else:
154
+ resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
155
  except Exception:
156
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
157
 
 
170
  except Exception:
171
  translation = transcription
172
 
173
+ reply = stt_chain.invoke({"speech": translation})
174
+ return {
175
+ "transcription": transcription,
176
+ "detected_language": detected_lang,
177
+ "translation": translation,
178
+ "reply": reply.strip() if isinstance(reply, str) else str(reply)
179
+ }
180
 
181
  @app.post("/autodoc")
182
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
183
  check_auth(authorization)
184
+ docs = autodoc_chain.invoke({"code": req.code})
185
+ return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
186
 
187
  @app.post("/sme/generate")
188
  async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
189
  check_auth(authorization)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  try:
191
+ user_prompt = payload.get("user_prompt", "")
192
+ context_docs = retriever.get_relevant_documents(user_prompt)
193
+ context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
194
+ response = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
195
+ return {"success": True, "data": response}
196
+ except HfHubHTTPError as e:
197
+ if "exceeded" in str(e).lower() or "quota" in str(e).lower():
198
+ return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
199
+ raise e
 
 
 
 
 
 
 
200
 
201
  # ----------------- MAIN -----------------
202
  if __name__ == "__main__":