Spaces:
Running
Running
| """ | |
| Learn Pathophysiology - FastAPI Backend | |
| Serves RAG + LLM API endpoints and Vue frontend static files. | |
| Deploy: HuggingFace Spaces (Docker) or run locally. | |
| """ | |
| import os | |
| import secrets | |
| import logging | |
| from pathlib import Path | |
| from datetime import datetime, timezone, timedelta | |
| import jwt | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from dotenv import load_dotenv | |
| from google import genai | |
| from google.genai import types | |
| from google.oauth2 import id_token as google_id_token | |
| from google.auth.transport import requests as google_requests | |
| import chromadb | |
| load_dotenv() | |
| logger = logging.getLogger(__name__) | |
# =============================================================================
# CONFIGURATION
# =============================================================================
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
CHROMA_DIR = os.environ.get("CHROMA_DIR", "../chroma_db")  # persisted ChromaDB dir
COLLECTION_NAME = "pathophysiology"
EMBEDDING_MODEL = "gemini-embedding-001"
DEFAULT_MODEL = "gemini-3-flash-preview"
RAG_TOP_K = 5  # chunks retrieved per RAG query
# Auth
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID", "")
# NOTE(review): when JWT_SECRET is unset, a fresh random secret is generated at
# import time — every restart (and every separate worker process) then
# invalidates all previously issued tokens. Set JWT_SECRET in production.
JWT_SECRET = os.environ.get("JWT_SECRET", secrets.token_urlsafe(32))
JWT_ALGORITHM = "HS256"
JWT_EXPIRY_DAYS = 7
AUTH_ENABLED = bool(GOOGLE_CLIENT_ID)  # disable auth if no client ID
# Selectable Gemini models, keyed by model id. Values are frontend display
# metadata: Croatian description plus a Warcraft-3-themed name/icon skin.
AVAILABLE_MODELS = {
    "gemini-3-flash-preview": {
        "name": "Gemini 3 Flash",
        "description": "Najnoviji i najbrzi model",
        "icon": "swords",
        "wc3_name": "Blademaster",
        "tier": "fast",
    },
    "gemini-2.5-flash": {
        "name": "Gemini 2.5 Flash",
        "description": "Brz i pouzdan",
        "icon": "bow",
        "wc3_name": "Shadow Hunter",
        "tier": "fast",
    },
    "gemini-2.5-pro": {
        "name": "Gemini 2.5 Pro",
        "description": "Najpametniji za kompleksne zadatke",
        "icon": "mage",
        "wc3_name": "Archmage",
        "tier": "smart",
    },
}
# System prompt (Croatian): assistant role and answering rules. The
# {rag_context} placeholder is filled via .format() with retrieved chunks.
SYSTEM_PROMPT = """Ti si "Learn Pathophysiology AI", strucni asistent za ucenje patofiziologije
za studente medicine.
ULOGA:
- Objasnjavaš patofiziološke koncepte jasno i precizno
- Koristiš primjere i analogije kad je moguce
- Povezuješ koncepte s klinickom praksom
- Odgovaraš na hrvatskom jeziku
KONTEKST IZ BAZE ZNANJA:
{rag_context}
PRAVILA:
1. Uvijek citiraj izvor kad koristiš informacije iz konteksta
2. Ako nisi siguran, reci to otvoreno
3. Koristi medicinsku terminologiju, ali objasni kompleksne termine
4. Budi koncizan ali potpun u odgovorima
5. Odgovaraj na hrvatskom jeziku"""
# =============================================================================
# SINGLETONS
# =============================================================================
# Lazily created, process-wide handles; see get_client() / get_collection().
_genai_client = None
_chroma_collection = None
def get_client():
    """Return the shared Gemini client, creating it on first use.

    Raises:
        HTTPException: 500 when GEMINI_API_KEY is not configured.
    """
    global _genai_client
    if _genai_client is not None:
        return _genai_client
    if not GEMINI_API_KEY:
        raise HTTPException(status_code=500, detail="GEMINI_API_KEY not configured")
    _genai_client = genai.Client(api_key=GEMINI_API_KEY)
    return _genai_client
def get_collection():
    """Return the cached Chroma collection, or None when unavailable.

    Tries CHROMA_DIR first, then a path relative to this source file.
    Any ChromaDB failure is logged and reported as None (callers degrade
    to "no context" behavior).
    """
    global _chroma_collection
    if _chroma_collection is not None:
        return _chroma_collection

    chroma_path = Path(CHROMA_DIR)
    if not chroma_path.exists():
        # Fallback: repo-root chroma_db relative to this file.
        fallback = Path(__file__).parent.parent.parent / "chroma_db"
        if not fallback.exists():
            return None
        chroma_path = fallback

    try:
        store = chromadb.PersistentClient(path=str(chroma_path))
        _chroma_collection = store.get_collection(COLLECTION_NAME)
    except Exception as e:
        logger.error(f"ChromaDB error: {e}")
        return None
    return _chroma_collection
# =============================================================================
# AUTH HELPERS
# =============================================================================
def create_jwt(email: str, name: str, picture: str = "") -> str:
    """Issue a signed session JWT for an authenticated Google user.

    Args:
        email: user's email; stored as the `sub` claim.
        name: display name.
        picture: avatar URL (optional).

    Returns:
        Encoded HS256 JWT valid for JWT_EXPIRY_DAYS.
    """
    # Fix: read the clock once so `exp` is exactly JWT_EXPIRY_DAYS after `iat`
    # (previously datetime.now() was called twice, skewing the two claims).
    now = datetime.now(timezone.utc)
    payload = {
        "sub": email,
        "name": name,
        "picture": picture,
        "iat": now,
        "exp": now + timedelta(days=JWT_EXPIRY_DAYS),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
def decode_jwt(token: str) -> dict | None:
    """Validate and decode a session JWT; None when expired or invalid."""
    try:
        return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
        # ExpiredSignatureError is a subclass of InvalidTokenError; both map
        # to the same "no valid session" result.
        return None
async def require_auth(request: Request):
    """FastAPI dependency — returns user dict or raises 401."""
    if not AUTH_ENABLED:
        # No Google client ID configured: run open, everyone is "Local User".
        return {"sub": "anonymous", "name": "Local User"}

    header = request.headers.get("Authorization", "")
    if not header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Not authenticated")

    claims = decode_jwt(header[len("Bearer "):])
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return claims
# =============================================================================
# RAG FUNCTIONS
# =============================================================================
def embed_query(text: str) -> list[float]:
    """Embed `text` with the configured Gemini embedding model."""
    response = get_client().models.embed_content(model=EMBEDDING_MODEL, contents=text)
    return response.embeddings[0].values
def query_rag(query_text: str, top_k: int = RAG_TOP_K):
    """Retrieve the top_k most relevant chunks for `query_text`.

    Returns:
        (formatted_context, citations) — formatted_context is the chunks
        joined with separators (or a Croatian "no context" message), and
        citations is a list of dicts with text snippet, score, source,
        page_num, and rank. Any failure degrades to the no-context message.
    """
    coll = get_collection()
    if coll is None:
        return "Nema dostupnog konteksta.", []
    try:
        hits = coll.query(
            query_embeddings=[embed_query(query_text)],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )
        contexts = []
        citations = []
        if hits and hits["documents"] and hits["documents"][0]:
            rows = zip(hits["documents"][0], hits["metadatas"][0], hits["distances"][0])
            for rank, (doc, meta, dist) in enumerate(rows, start=1):
                contexts.append(doc)
                # Map the Chroma distance onto a clamped 0..1 similarity score.
                score = max(0, 1 - dist / 2)
                snippet = doc if len(doc) <= 600 else doc[:600] + "..."
                citations.append({
                    "text": snippet,
                    "score": round(score, 3),
                    "source": meta.get("source", "Baza znanja"),
                    "page_num": meta.get("page_num", "?"),
                    "rank": rank,
                })
        if contexts:
            return "\n\n---\n\n".join(contexts), citations
        return "Nema konteksta.", citations
    except Exception as e:
        logger.error(f"RAG error: {e}")
        return "Nema dostupnog konteksta.", []
def generate_chat_response(message: str, history: list, model_name: str = ""):
    """Answer a student message using RAG-grounded Gemini generation.

    Args:
        message: the new student message.
        history: prior turns as {"role": ..., "content": ...} dicts;
            only the last 10 are included in the prompt.
        model_name: AVAILABLE_MODELS key; unknown/empty falls back to
            DEFAULT_MODEL.

    Returns:
        (reply_text, citations) tuple.
    """
    if not model_name or model_name not in AVAILABLE_MODELS:
        model_name = DEFAULT_MODEL
    c = get_client()
    rag_context, citations = query_rag(message)

    # Flat text transcript: system prompt, recent turns, then the new message.
    parts = [SYSTEM_PROMPT.format(rag_context=rag_context)]
    for turn in (history or [])[-10:]:
        speaker = "Student" if turn.get("role", "user") == "user" else "Asistent"
        parts.append(f"{speaker}: {turn.get('content', '')}")
    parts.append(f"Student: {message}")

    response = c.models.generate_content(
        model=model_name,
        contents="\n\n".join(parts),
        config=types.GenerateContentConfig(
            temperature=0.7,
            max_output_tokens=8192,
            top_p=0.9,
        ),
    )
    return response.text, citations
def do_analyze_image(image_bytes: bytes, question: str = "", model_name: str = "",
                     mime_type: str = "image/jpeg"):
    """Analyze one page image from the study materials.

    Two Gemini passes: (1) extract the page topic/keywords to seed a RAG
    lookup, then (2) produce the full analysis grounded in the retrieved
    context (and the student's question, if any).

    Args:
        image_bytes: raw image bytes.
        question: optional student question about the page.
        model_name: AVAILABLE_MODELS key; unknown/empty falls back to
            DEFAULT_MODEL.
        mime_type: MIME type of image_bytes. New, backward-compatible
            parameter — previously "image/jpeg" was hard-coded, which
            mislabels PNG/WebP uploads.

    Returns:
        (analysis_text, citations) tuple.
    """
    if not model_name or model_name not in AVAILABLE_MODELS:
        model_name = DEFAULT_MODEL
    c = get_client()
    # Pass 1: short keyword extraction used only to build the RAG query.
    extract_resp = c.models.generate_content(
        model=model_name,
        contents=[
            types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
            "Izvuci glavni topic i kljucne rijeci s ove stranice. Odgovori kratko."
        ],
        config=types.GenerateContentConfig(temperature=0.3, max_output_tokens=200)
    )
    rag_context, citations = query_rag(extract_resp.text, top_k=3)
    if question:
        prompt = (
            f"Analiziraj ovu stranicu iz materijala za patofiziologiju "
            f"i odgovori na pitanje studenta.\n\n"
            f"PITANJE: {question}\n\n"
            f"KONTEKST IZ BAZE ZNANJA:\n{rag_context}\n\n"
            f"Odgovori detaljno na hrvatskom jeziku."
        )
    else:
        prompt = (
            f"Analiziraj ovu stranicu iz materijala za patofiziologiju.\n\n"
            f"1. Prepoznaj glavni topic\n2. Izvuci kljucne pojmove\n"
            f"3. Sazmi glavne tocke\n4. Objasni klinicku vaznost\n\n"
            f"KONTEKST IZ BAZE ZNANJA:\n{rag_context}\n\n"
            f"Odgovori na hrvatskom jeziku."
        )
    # Pass 2: full grounded analysis of the same image.
    response = c.models.generate_content(
        model=model_name,
        contents=[
            types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
            prompt
        ],
        config=types.GenerateContentConfig(temperature=0.5, max_output_tokens=8192)
    )
    return response.text, citations
# =============================================================================
# FASTAPI APP
# =============================================================================
app = FastAPI(title="Learn Pathophysiology API", version="1.0.0")

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is an
# over-broad CORS policy (browsers reject the literal wildcard when credentials
# are sent). Consider pinning the frontend origin(s) in production — confirm
# the intended deployment origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Request / Response Models ---
class ChatRequest(BaseModel):
    # The student's new message.
    message: str
    # AVAILABLE_MODELS key; empty string means "use DEFAULT_MODEL".
    model: str = ""
    # Prior turns as {"role": ..., "content": ...} dicts (pydantic copies
    # the mutable default per instance, so `= []` is safe here).
    history: list = []
# --- Auth Models ---
class GoogleAuthRequest(BaseModel):
    credential: str  # Google ID token
# --- Auth Endpoints ---
# NOTE(review): no @app.get/@app.post decorator is visible on the handlers in
# this view of the file — confirm route registration was not lost upstream.
async def auth_config():
    """Tell the frontend if auth is required and the Google Client ID."""
    client_id = GOOGLE_CLIENT_ID if AUTH_ENABLED else None
    return {"auth_enabled": AUTH_ENABLED, "google_client_id": client_id}
async def auth_google(req: GoogleAuthRequest):
    """Verify Google ID token and return a JWT session token.

    Raises:
        HTTPException: 400 when auth is disabled, 401 on a bad Google token.
    """
    if not AUTH_ENABLED:
        raise HTTPException(status_code=400, detail="Auth not enabled")
    try:
        idinfo = google_id_token.verify_oauth2_token(
            req.credential,
            google_requests.Request(),
            GOOGLE_CLIENT_ID,
        )
    except ValueError as e:
        logger.error(f"Google auth failed: {e}")
        raise HTTPException(status_code=401, detail="Invalid Google token")

    email = idinfo.get("email", "")
    name = idinfo.get("name", email)
    picture = idinfo.get("picture", "")
    return {
        "token": create_jwt(email, name, picture),
        "user": {"email": email, "name": name, "picture": picture},
    }
async def auth_me(user=Depends(require_auth)):
    """Return the current user's info from their JWT."""
    email, name, picture = (user.get(claim, "") for claim in ("sub", "name", "picture"))
    return {"email": email, "name": name, "picture": picture}
# --- API Endpoints (public) ---
async def health():
    """Health probe: Chroma document count and API-key presence."""
    coll = get_collection()
    doc_count = coll.count() if coll else 0
    return {
        "status": "ok",
        "chroma_docs": doc_count,
        "has_api_key": bool(GEMINI_API_KEY),
    }
async def list_models():
    """Expose the selectable model catalogue and the default model id."""
    return {"models": AVAILABLE_MODELS, "default": DEFAULT_MODEL}
async def stats():
    """Report the knowledge-base size for the configured collection."""
    coll = get_collection()
    doc_count = coll.count() if coll else 0
    return {"documents": doc_count, "collection": COLLECTION_NAME}
# --- API Endpoints (protected) ---
# NOTE(review): no @app.post decorator is visible on these handlers in this
# view of the file — confirm route registration was not lost upstream.
async def chat(req: ChatRequest, user=Depends(require_auth)):
    """Answer a chat message with RAG-grounded generation (auth required)."""
    try:
        model = req.model or DEFAULT_MODEL
        reply, citations = generate_chat_response(req.message, req.history, model)
        return {"reply": reply, "citations": citations, "model_used": model}
    except HTTPException:
        # Fix: propagate intentional HTTP errors (e.g. 500 "GEMINI_API_KEY not
        # configured" from get_client) unchanged instead of re-wrapping them
        # below with a str(HTTPException) detail.
        raise
    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def analyze_image_endpoint(
    image: UploadFile = File(...),
    question: str = Form(""),
    model: str = Form(""),
    user=Depends(require_auth),
):
    """Analyze an uploaded page image (auth required)."""
    try:
        model_name = model or DEFAULT_MODEL
        image_bytes = await image.read()
        analysis, citations = do_analyze_image(image_bytes, question, model_name)
        return {"analysis": analysis, "citations": citations, "model_used": model_name}
    except HTTPException:
        # Fix: let intentional HTTP errors (401/500 from dependencies) pass
        # through rather than re-wrapping them as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Image analysis error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# --- Google OAuth Redirect Handler ---
# When using ux_mode='redirect', Google POSTs credential as form data to login_uri
# NOTE(review): no route decorator is visible on this handler in this view of
# the file — confirm the @app.post registration was not lost.
async def google_redirect_callback(request: Request):
    """Handle Google Sign-In redirect (POST with credential in form data)."""
    form = await request.form()
    credential = form.get("credential", "")
    if credential and AUTH_ENABLED:
        try:
            idinfo = google_id_token.verify_oauth2_token(
                str(credential),
                google_requests.Request(),
                GOOGLE_CLIENT_ID,
            )
            email = idinfo.get("email", "")
            name = idinfo.get("name", email)
            picture = idinfo.get("picture", "")
            token = create_jwt(email, name, picture)
            # Return an HTML page that stores the token and redirects to the app.
            # NOTE(review): the token is interpolated into a single-quoted JS
            # string — presumably safe because a JWT is base64url segments
            # joined by dots, but confirm no quoting-sensitive characters can
            # appear.
            html = f"""<!DOCTYPE html>
<html><head><title>Logging in...</title></head>
<body><script>
localStorage.setItem('lp_auth_token', '{token}');
window.location.href = '/';
</script><p>Logging in...</p></body></html>"""
            from fastapi.responses import HTMLResponse
            return HTMLResponse(content=html)
        except Exception as e:
            # Best-effort: any verification failure falls through to index.html.
            logger.error(f"Google redirect auth failed: {e}")
    # Fallback: serve index.html
    static_dir = Path(__file__).parent / "static"
    return FileResponse(str(static_dir / "index.html"))
# --- Serve Vue Frontend (production) ---
static_dir = Path(__file__).parent / "static"
if static_dir.exists():
    # NOTE(review): serve_index is defined but never registered as a route in
    # this view of the file (no decorator visible), and the StaticFiles mount
    # below already serves index.html via html=True — confirm whether a
    # decorator was lost upstream or this is dead code.
    async def serve_index():
        return FileResponse(str(static_dir / "index.html"))
    app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("api:app", host="0.0.0.0", port=7860, reload=True) | |