import os
import re
import json
import tempfile
import traceback
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from spitch import Spitch
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langdetect import detect, DetectorFactory
from huggingface_hub.utils import HfHubHTTPError

from smebuilder_vector import retriever  # retriever that exposes .get_relevant_documents(...)

DetectorFactory.seed = 0
# ----------------- CONFIG -----------------
SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "")

if not SPITCH_API_KEY:
    raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")

# Init Spitch
os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
spitch_client = Spitch()
# HuggingFace LLM
# NOTE: pass generation params explicitly (pydantic validation requires explicit params)
llm = HuggingFaceEndpoint(
    repo_id=HF_MODEL,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    repetition_penalty=1.1,
    max_new_tokens=2048,
)
# FastAPI app
app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")

# CORS
# NOTE: browsers reject a wildcard origin when allow_credentials=True,
# so set ALLOWED_ORIGIN to a concrete origin in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Authorization", "Content-Type"],
)
# ----------------- PROMPT TEMPLATES -----------------
chat_template = """You are DevAssist, an AI coding assistant.
Guidelines:
- Always format responses in Markdown.
- Do NOT use inline formatting like bold (**), italics (*), or underlines.
- Use plain text section headers (e.g., Explanation:, Steps to Fix:, Fixed Code:).
- Use bullet points (-, 1.) for steps.
- Use fenced code blocks (```python ... ```) for code.
- Be friendly yet professional; explain step by step.
Question: {question}
Answer:
"""

stt_chat_template = """You are DevAssist, an AI coding assistant.
- The input is transcribed speech. Interpret it as a dev question.
- Provide clear answers with code examples (use markdown triple backticks).
- If input is unclear, ask a clarifying question.
Spoken Question: {speech}
Answer:
"""

autodoc_template = """You are DevAssist DocBot.
- Read the code and produce professional documentation in markdown.
Code: {code}
Documentation:
"""

sme_template = """
You are a senior full-stack engineer specializing in modern front-end development.
Your job is to generate production-ready code for websites and apps.
Guidelines:
- Always return three separate files: index.html, styles.css, and script.js
- HTML must be semantic, responsive, and mobile-first (include <meta name="viewport">)
- CSS should use Flexbox/Grid and include hover/transition effects
- JavaScript should add interactivity (e.g. button actions, basic animations, toggles)
- Include a hero section, a feature grid, testimonials, and footer
- Use realistic content (avoid lorem ipsum), sensible copy, and accessible markup
- Return only valid JSON with the keys: "files" -> { "index.html": "...", "styles.css": "...", "script.js": "..." }
User Prompt:
{user_prompt}
Context:
{context}
Return:
"""
# ----------------- CHAINS -----------------
chat_chain = PromptTemplate(input_variables=["question"], template=chat_template) | llm
stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_template) | llm
autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template) | llm
sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template) | llm
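# Example (sketch): a PromptTemplate | llm pipeline returns the model's raw
# generated text when invoked (typically a plain string for HuggingFaceEndpoint):
#   chat_chain.invoke({"question": "What does *args mean in Python?"})
#   -> "Explanation: ..."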
# ----------------- REQUEST MODELS -----------------
class ChatRequest(BaseModel):
    question: str

class AutoDocRequest(BaseModel):
    code: str
# ----------------- AUTH -----------------
def check_auth(authorization: Optional[str] = None):
    if not PROJECT_API_KEY:
        # No API key enforced in this environment
        return
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    token = authorization.split(" ", 1)[1]
    if token != PROJECT_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid token")
# ----------------- HELPERS -----------------
def try_parse_json(maybe_str: str):
    """Parse JSON from a model reply; fall back to a trailing JSON object, or return None."""
    try:
        return json.loads(maybe_str)
    except Exception:
        # attempt to find a JSON substring (models often wrap JSON in prose)
        m = re.search(r"\{[\s\S]*\}\s*$", maybe_str.strip())
        if m:
            try:
                return json.loads(m.group(0))
            except Exception:
                return None
        return None
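# Example: the fallback regex tolerates prose before the JSON object, so a
# hypothetical model reply like
#   'Sure! {"files": {"index.html": "<!doctype html>..."}}'
# still parses to {"files": {"index.html": "<!doctype html>..."}}.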
# ----------------- ENDPOINTS -----------------
# NOTE: the route decorators below were missing from this listing; the paths
# are assumed and may differ from the original deployment.
@app.get("/")
def root():
    return {"status": "DevAssist AI Backend running"}
@app.post("/chat")
def chat(req: ChatRequest, authorization: Optional[str] = Header(None)):
    check_auth(authorization)
    try:
        answer = chat_chain.invoke({"question": req.question})
        return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
    except HfHubHTTPError as e:
        if "exceeded" in str(e).lower() or "quota" in str(e).lower():
            return {"reply": "⚠️ Daily token limit reached. Try again in 24 hours."}
        raise
@app.post("/stt")
async def stt_audio(
    file: UploadFile = File(...),
    lang_hint: Optional[str] = None,
    authorization: Optional[str] = Header(None),
):
    check_auth(authorization)
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
    # create temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
        content = await file.read()
        tf.write(content)
        tmp_path = tf.name
    # read the audio once so no file handles are left open
    with open(tmp_path, "rb") as f:
        audio_bytes = f.read()
    try:
        # transcribe
        if lang_hint:
            resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
        else:
            resp = spitch_client.speech.transcribe(content=audio_bytes)
    except Exception:
        # fall back to English transcription if something fails
        resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
    transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
    try:
        detected_lang = detect(transcription) if transcription.strip() else "en"
    except Exception:
        detected_lang = "en"
    translation = transcription
    if detected_lang != "en":
        try:
            translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
            translation = getattr(translation_resp, "text", "") or (translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation)
        except Exception:
            translation = transcription
    # call the STT chain (LLM)
    try:
        reply = stt_chain.invoke({"speech": translation})
    except Exception:
        # on LLM problems, return the transcription anyway
        reply = f"(LLM error) Transcription: {translation}"
    # clean up temp file to avoid storage bloat
    try:
        os.remove(tmp_path)
    except Exception:
        pass
    return {
        "transcription": transcription,
        "detected_language": detected_lang,
        "translation": translation,
        "reply": reply.strip() if isinstance(reply, str) else str(reply),
    }
@app.post("/autodoc")
def autodoc(req: AutoDocRequest, authorization: Optional[str] = Header(None)):
    check_auth(authorization)
    docs = autodoc_chain.invoke({"code": req.code})
    return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
@app.post("/sme/generate")
async def sme_generate(payload: dict = Body(...), authorization: Optional[str] = Header(None)):
    """
    Expected payload: { "user_prompt": "Create ...", (optionally) "force_simple": true }
    Returns: success, plus data on success or error details on failure.
    """
    check_auth(authorization)
    user_prompt = payload.get("user_prompt", "")
    if not user_prompt or not user_prompt.strip():
        raise HTTPException(status_code=400, detail="user_prompt is required")
    # Get context from the retriever (if available)
    try:
        context_docs = retriever.get_relevant_documents(user_prompt) if retriever else []
        context = "\n\n".join([getattr(d, "page_content", str(d)) for d in context_docs]) if context_docs else "No extra context"
    except Exception:
        context = "No extra context"
    # Invoke the SME chain
    try:
        raw = sme_chain.invoke({"user_prompt": user_prompt, "context": context})
        # Try to parse the returned JSON
        parsed = None
        if isinstance(raw, str):
            parsed = try_parse_json(raw)
        elif isinstance(raw, dict):
            parsed = raw
        if parsed:
            return {"success": True, "data": parsed}
        # If the model didn't return strict JSON, return a helpful error plus the raw output so the frontend can show it
        return {"success": False, "error": "LLM did not return valid JSON", "raw": raw}
    except HfHubHTTPError as e:
        if "exceeded" in str(e).lower() or "quota" in str(e).lower():
            return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
        raise
    except Exception as e:
        # Debug info for devs (avoid leaking sensitive internals in production)
        return {"success": False, "error": "SME generation failed", "details": str(e), "trace": traceback.format_exc()}
@app.post("/sme/speech-generate")
async def sme_speech_generate(
    file: UploadFile = File(...),
    lang_hint: Optional[str] = None,
    authorization: Optional[str] = Header(None),
):
    check_auth(authorization)
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
        content = await file.read()
        tf.write(content)
        tmp_path = tf.name
    # read the audio once so no file handles are left open
    with open(tmp_path, "rb") as f:
        audio_bytes = f.read()
    try:
        if lang_hint:
            resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
        else:
            resp = spitch_client.speech.transcribe(content=audio_bytes)
    except Exception:
        resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
    transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
    try:
        detected_lang = detect(transcription) if transcription.strip() else "en"
    except Exception:
        detected_lang = "en"
    translation = transcription
    if detected_lang != "en":
        try:
            translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
            translation = getattr(translation_resp, "text", "") or (translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation)
        except Exception:
            translation = transcription
    # Get context docs for the transcribed prompt
    try:
        context_docs = retriever.get_relevant_documents(translation) if retriever else []
        context = "\n\n".join([getattr(d, "page_content", str(d)) for d in context_docs]) if context_docs else "No extra context"
    except Exception:
        context = "No extra context"
    # Invoke the SME chain; the temp file is removed in the finally block below
    try:
        raw = sme_chain.invoke({"user_prompt": translation, "context": context})
        parsed = None
        if isinstance(raw, str):
            parsed = try_parse_json(raw)
        elif isinstance(raw, dict):
            parsed = raw
        if parsed:
            return {
                "success": True,
                "transcription": transcription,
                "detected_language": detected_lang,
                "translation": translation,
                "sme_site": parsed,
            }
        return {
            "success": False,
            "error": "LLM did not return valid JSON",
            "raw": raw,
            "transcription": transcription,
            "detected_language": detected_lang,
            "translation": translation,
        }
    except HfHubHTTPError as e:
        if "exceeded" in str(e).lower() or "quota" in str(e).lower():
            return {"success": False, "error": "⚠️ Token quota for today has been used. Please come back in 24 hours."}
        raise
    except Exception as e:
        return {"success": False, "error": "SME generation failed", "details": str(e), "trace": traceback.format_exc()}
    finally:
        # clean up the temp file regardless of outcome
        try:
            os.remove(tmp_path)
        except Exception:
            pass
# ----------------- MAIN -----------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)