Spaces:
Running
Running
import logging
import os
import re

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from retrieval import search, EXACT_SI, EXACT_TA, normalize
from intents import detect_smalltalk, smalltalk_reply
from firestore_client import get_advice_by_id
# Optional Qwen output layer.
# The fine-tuned model module is optional. If it cannot be imported (missing
# package, or any error raised while the module initializes), the hook stays
# None and chat replies use the retrieved answer text unchanged.
try:
    from finetuned_llm import generate_grounded_answer
except Exception as exc:  # broad on purpose: module init may fail with non-ImportError errors
    generate_grounded_answer = None
    # Record why the layer is disabled instead of failing silently.
    logging.getLogger("coco_guide").warning(
        "Fine-tuned model layer unavailable (finetuned_llm import failed): %r", exc
    )
# The ASGI application object; routes and middleware are attached to it below.
app = FastAPI(
    title="Coco-Guide Backend",
    version="1.3",
)

# -----------------------------
# Logging
# -----------------------------
# Root handler at INFO level; this module logs through the "coco_guide" logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco_guide")
# -----------------------------
# CORS
# -----------------------------
# NOTE(review): DEBUG defaults to "true". Besides wildcard CORS, DEBUG also
# adds debug_hits to chat responses and defines the Firestore test helper.
# Set DEBUG=false explicitly in production deployments.
DEBUG = os.getenv("DEBUG", "true").lower() == "true"

# Origins allowed when DEBUG is off: a comma-separated list taken from the
# CORS_ALLOWED_ORIGINS env var. The default keeps the previous hard-coded
# placeholder, so existing deployments behave the same.
_DEFAULT_ALLOWED_ORIGINS = "https://your-frontend-domain.com"
CORS_ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOWED_ORIGINS", _DEFAULT_ALLOWED_ORIGINS).split(",")
    if origin.strip()
]

if DEBUG:
    # Any origin may call the API. Credentials stay disabled: browsers reject
    # a wildcard Access-Control-Allow-Origin on credentialed requests.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    # Only the configured frontend origins, with credentials allowed.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=CORS_ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# -----------------------------
# Config
# -----------------------------
# Enables the optional model rewrite of semantic answers (it also requires
# finetuned_llm to have imported successfully).
USE_FINE_TUNED_MODEL = os.getenv("USE_FINE_TUNED_MODEL", "false").lower() == "true"

# Semantic-search scores below this value get the generic fallback reply.
FALLBACK_THRESHOLD = float(os.getenv("FALLBACK_THRESHOLD", "0.60"))

# Scores in [FALLBACK_THRESHOLD, CLARIFY_THRESHOLD) get a clarification prompt;
# scores at or above it are answered.
CLARIFY_THRESHOLD = float(os.getenv("CLARIFY_THRESHOLD", "0.72"))
# -----------------------------
# Request Schema
# -----------------------------
class ChatRequest(BaseModel):
    """Body of a chat request.

    ``message`` must be 1-500 characters long (enforced by validation).
    ``language`` is accepted as any string here; the chat handler normalizes
    it and rejects anything other than "si" or "ta".
    """

    # The user's question text.
    message: str = Field(..., min_length=1, max_length=500)
    # Reply language code.
    language: str
# -----------------------------
# Messages
# -----------------------------
# User-facing canned replies, one Sinhala (_SI) and one Tamil (_TA) variant each.
# NOTE(review): the two FALLBACK texts do not say the same thing. Roughly, the
# Sinhala one says "I can only help with coconut cultivation; please ask
# again", while the Tamil one says "we don't have that information; contact
# a coconut cultivation officer". Confirm this is intended.
# NOTE(review): the Sinhala strings spell "ප්රශ්නය" / "දිස්ත්රික්කයේ" without
# a ZWJ (U+200D); confirm whether the rakaransaya form was intended.

# Generic "cannot help with this" reply.
FALLBACK_SI = "කණගාටුයි, මට සහාය විය හැක්කේ පොල් වගාවට අදාළ කරුණු සඳහා පමණි. කරුණාකර ඔබේ ප්රශ්නය නැවත විමසන්න."
FALLBACK_TA = "மன்னிக்கவும், அந்தத் தகவல் தற்போது எங்களிடம் இல்லை. தயவுசெய்து மேலதிக ஆலோசனைகளுக்கு தென்னை பயிர்ச்செய்கை அதிகாரியைத் தொடர்பு கொள்ளவும்."

# "Please describe your question in more detail."
CLARIFY_SI = "කරුණාකර ඔබගේ ප්රශ්නය තව විස්තර කරන්න."
CLARIFY_TA = "தயவுசெய்து உங்கள் கேள்வியை மேலும் விளக்கவும்."

# Reply for questions about districts other than Kurunegala.
LOCATION_FALLBACK_SI = "කණගාටුයි, මෙම පද්ධතිය කුරුණෑගල දිස්ත්රික්කයේ පොල් වගාවට අදාළ උපදෙස් සඳහා පමණක් සීමා වී ඇත."
LOCATION_FALLBACK_TA = "மன்னிக்கவும், இந்த அமைப்பு குருநாகல் மாவட்டத்திலுள்ள தென்னைப் பயிர்ச்செய்கை தொடர்பான ஆலோசனைகளுக்கே மட்டுப்படுத்தப்பட்டுள்ளது."
# -----------------------------
# Domain / Location Guards
# -----------------------------
# Spellings of the supported district (English, Sinhala, Tamil). A mention of
# any of these keeps a question in scope even if other districts are named.
KURUNEGALA_TERMS = {"kurunegala", "කුරුණෑගල", "குருநாகல்"}
# Other Sri Lankan districts (English, Sinhala, Tamil spellings). A question
# naming one of these, without also naming Kurunegala, gets the location
# fallback reply. All 24 districts other than Kurunegala are listed; Matale
# and Mullaitivu were previously missing.
# NOTE(review): Tamil "காலி" (Galle) is also the everyday word for "empty",
# so it can trigger the location guard on unrelated Tamil questions.
NON_KURUNEGALA_TERMS = {
    "colombo", "කොළඹ", "கொழும்பு",
    "gampaha", "ගම්පහ", "கம்பஹா",
    "kandy", "මහනුවර", "கண்டி",
    "galle", "ගාල්ල", "காலி",
    "matara", "මාතර", "மாத்தறை",
    "jaffna", "යාපනය", "யாழ்ப்பாணம்",
    "batticaloa", "මඩකලපුව", "மட்டக்களப்பு",
    "anuradhapura", "අනුරාධපුර", "அனுராதபுரம்",
    "polonnaruwa", "පොළොන්නරුව", "பொலன்னறுவை",
    "badulla", "බදුල්ල", "பதுளை",
    "ratnapura", "රත්නපුර", "இரத்தினபுரி",
    "kalutara", "කළුතර", "களுத்துறை",
    "trincomalee", "ත්රිකුණාමලය", "திருகோணமலை",
    "hambantota", "හම්බන්තොට", "அம்பாந்தோட்டை",
    "ampara", "අම්පාර", "அம்பாறை",
    "nuwara eliya", "නුවරඑළිය", "நுவரெலியா",
    "vavuniya", "වව්නියා", "வவுனியா",
    "kilinochchi", "කිලිනොච්චි", "கிளிநொச்சி",
    "mannar", "මන්නාරම", "மன்னார்",
    "puttalam", "පුත්තලම", "புத்தளம்",
    "kegalle", "කෑගල්ල", "கேகாலை",
    "monaragala", "මොනරාගල", "மொணராகலை",
    "matale", "මාතලේ", "மாத்தளை",
    "mullaitivu", "මුලතිව්", "முல்லைத்தீவு",
}
# Keywords marking a question as clearly unrelated to coconut cultivation;
# a hit short-circuits to the generic fallback reply before retrieval.
# NOTE(review): Sinhala/Tamil terms are matched as substrings, which gives
# false positives on legitimate farming questions, e.g.:
#   "ණය" (loan) occurs inside "ප්රමාණය" (amount/quantity),
#   "කාර්" (car) occurs inside "කාර්යාලය" (office),
#   "தேர்வு" means both "exam" and "selection" (e.g. seed selection),
#   "வரி" (tax) occurs inside "வரிசை" (row),
#   "oil"/"තෙල්"/"எண்ணெய்" also block questions about coconut oil.
# Review these entries with a domain expert.
NON_DOMAIN_TERMS = {
    # --- English ---
    "car", "bike", "phone", "laptop", "school", "exam", "movie", "music",
    "politics", "election", "cricket", "football", "passport", "bank",
    "insurance", "bus", "train", "airport", "visa", "hotel", "restaurant",
    "computer", "wifi", "bitcoin", "tax", "loan", "job", "university",
    "doctor", "hospital", "weather", "score", "match", "flight", "ticket",
    "salary", "mobile", "camera", "oil", "world",
    # --- Sinhala ---
    "කාර්", "බයික්", "ෆෝන්", "ලැප්ටොප්", "පාසල", "විභාග", "චිත්රපට",
    "දේශපාලන", "ක්රිකට්", "පාස්පෝට්", "බැංකු", "රක්ෂණ",
    "බස්", "දුම්රිය", "ගුවන් තොටුපළ", "විසා", "හෝටල", "ආපනශාලා",
    "කම්පියුටර්", "වයිෆයි", "බදු", "ණය", "රැකියා", "විශ්වවිද්යාල",
    "වෛද්ය", "රෝහල", "කාලගුණය", "ලකුණු", "ගුවන් ගමන්", "ටිකට්", "වැටුප්",
    "ජංගම", "කැමරා", "තෙල්", "ලෝකය",
    # --- Tamil ---
    "கார்", "பைக்", "தொலைபேசி", "லாப்டாப்", "பாடசாலை", "தேர்வு",
    "திரைப்படம்", "அரசியல்", "கிரிக்கெட்", "காப்பீடு", "வங்கி", "பாஸ்போர்ட்",
    "பஸ்", "ரயில்", "விமான நிலையம்", "விசா", "ஹோட்டல்", "உணவகம்",
    "கம்ப்யூட்டர்", "வைஃபை", "வரி", "கடன்", "வேலை", "பல்கலைக்கழகம்",
    "மருத்துவர்", "மருத்துவமனை", "வானிலை", "மதிப்பெண்", "விமானம்", "டிக்கெட்",
    "சம்பளம்", "மொபைல்", "கேமரா", "எண்ணெய்", "உலகம்",
}
# -----------------------------
# Helpers
# -----------------------------
def _fallback_text(lang: str) -> str:
    """Return the generic fallback reply: Tamil for "ta", Sinhala otherwise."""
    if lang == "ta":
        return FALLBACK_TA
    return FALLBACK_SI
def _clarify_text(lang: str) -> str:
    """Return the "please clarify" reply: Tamil for "ta", Sinhala otherwise."""
    if lang == "ta":
        return CLARIFY_TA
    return CLARIFY_SI
def _location_fallback_text(lang: str) -> str:
    """Return the out-of-district reply: Tamil for "ta", Sinhala otherwise."""
    if lang == "ta":
        return LOCATION_FALLBACK_TA
    return LOCATION_FALLBACK_SI
def _json_response(
    reply: str,
    match_type: str,
    category: str,
    language: str,
    source_id: str = "",
    confidence: float = 0.0,
    answer_source: str = "",
    debug_hits=None,
):
    """Wrap a chat outcome in the JSON envelope shared by every reply path.

    Args:
        reply: Text shown to the user.
        match_type: Outcome label (the chat handler uses "fallback",
            "smalltalk", "error", "clarification", "exact", "semantic").
        category: Advice category of the matched item, or a guard label.
        language: Language code of the reply.
        source_id: Id of the matched item; "" when there is none.
        confidence: Match score; coerced to float and rounded to 4 places.
        answer_source: Where the reply text came from (e.g. "guard",
            "semantic", "dataset", "firestore").
        debug_hits: Optional diagnostics, added under "debug_hits" only when
            DEBUG is enabled and the value is not None.

    Returns:
        A ``JSONResponse`` carrying the fields above.
    """
    body = dict(
        reply=reply,
        match_type=match_type,
        category=category,
        language=language,
        source_id=source_id,
        confidence=round(float(confidence), 4),
        answer_source=answer_source,
    )
    if debug_hits is not None and DEBUG:
        body["debug_hits"] = debug_hits
    return JSONResponse(content=body)
def _contains_any_phrase(text: str, phrases: set[str]) -> bool:
    """Return True if any of *phrases* occurs in the normalized, lower-cased *text*.

    ASCII (English / romanized) phrases must appear as whole words, optionally
    followed by a plural "s"/"es". This stops short keywords from firing
    inside unrelated words, e.g. "oil" in "soil", "car" in "care", "train" in
    "training".

    Sinhala and Tamil phrases are still matched as plain substrings. Python's
    regex word boundaries treat their combining vowel signs as non-word
    characters, so ``\\b`` would split those words apart.

    Args:
        text: Raw or already-normalized user text; it is passed through
            ``normalize`` and lower-cased here.
        phrases: Keywords to look for; compared case-insensitively.

    Returns:
        True as soon as one phrase matches, otherwise False.
    """
    t = normalize(text).lower()
    for phrase in phrases:
        p = phrase.lower()
        if p.isascii():
            # Not preceded or followed by an ASCII letter/digit. Patterns are
            # cached by the re module, so they are compiled only once.
            pattern = rf"(?<![a-z0-9]){re.escape(p)}(?:e?s)?(?![a-z0-9])"
            if re.search(pattern, t):
                return True
        elif p in t:
            return True
    return False
def _is_outside_kurunegala(text: str) -> bool:
    """Return True when *text* names another district but not Kurunegala.

    A Kurunegala mention always wins, so a question that names Kurunegala
    alongside another district is still treated as in scope.
    """
    lowered = normalize(text).lower()
    mentions_home_district = _contains_any_phrase(lowered, KURUNEGALA_TERMS)
    return (not mentions_home_district) and _contains_any_phrase(lowered, NON_KURUNEGALA_TERMS)
def _is_explicitly_non_domain(text: str) -> bool:
    """Return True when *text* contains any keyword from NON_DOMAIN_TERMS."""
    found_off_topic_keyword = _contains_any_phrase(text, NON_DOMAIN_TERMS)
    return found_off_topic_keyword
def startup_checks():
    """Validate the threshold configuration and log the effective settings.

    The registration of this function as a startup hook is not visible here;
    presumably it is decorated as one — TODO confirm.

    Raises:
        ValueError: if FALLBACK_THRESHOLD exceeds CLARIFY_THRESHOLD, which
            would leave the clarification score band empty.
    """
    if FALLBACK_THRESHOLD > CLARIFY_THRESHOLD:
        raise ValueError("FALLBACK_THRESHOLD cannot be greater than CLARIFY_THRESHOLD")

    startup_settings = dict(
        event="startup",
        use_fine_tuned_model=USE_FINE_TUNED_MODEL,
        fallback_threshold=FALLBACK_THRESHOLD,
        clarify_threshold=CLARIFY_THRESHOLD,
        debug=DEBUG,
    )
    logger.info(startup_settings)
def health():
    """Report liveness together with the effective runtime configuration."""
    return dict(
        status="ok",
        use_fine_tuned_model=USE_FINE_TUNED_MODEL,
        # True only if the optional finetuned_llm import succeeded.
        fine_tuned_model_available=generate_grounded_answer is not None,
        fallback_threshold=FALLBACK_THRESHOLD,
        clarify_threshold=CLARIFY_THRESHOLD,
        debug=DEBUG,
    )
if DEBUG:
    def test_firestore(doc_id: str):
        """Debug-only helper: fetch one advice document by id and report the result.

        Lookup errors are reported in the returned dict instead of being raised.
        """
        try:
            doc = get_advice_by_id(doc_id)
            found = bool(doc)
        except Exception as e:
            return {"ok": False, "error": str(e), "doc_id": doc_id}

        if not found:
            return {"ok": False, "error": "Document not found", "doc_id": doc_id}
        return {"ok": True, "doc_id": doc_id, "doc": doc}
def chat(req: ChatRequest):
    """Answer one Sinhala ("si") or Tamil ("ta") coconut-cultivation question.

    The request passes through a chain of checks; the first one that produces
    a reply wins:

    1. language validation and empty-message check;
    2. smalltalk detection (``detect_smalltalk``);
    3. location guard: another district named without Kurunegala;
    4. explicit non-domain keyword guard;
    5. exact lookup in ``EXACT_SI`` / ``EXACT_TA``;
    6. semantic search. Scores below ``FALLBACK_THRESHOLD`` get the fallback
       text; scores below ``CLARIFY_THRESHOLD`` get a clarification request;
    7. answer text from Firestore (by the matched item's id), otherwise from
       the matched dataset item itself;
    8. optional rewrite by the fine-tuned model (semantic matches only).

    Returns:
        A ``JSONResponse`` built by ``_json_response``.

    Raises:
        HTTPException: 400 if ``language`` is not "si" or "ta".
    """
    msg = (req.message or "").strip()
    lang = (req.language or "").strip().lower()
    if lang not in {"si", "ta"}:
        raise HTTPException(status_code=400, detail="Invalid language. Use 'si' or 'ta'.")

    # Whitespace-only messages pass the schema's min_length but are empty here.
    if not msg:
        return _json_response(
            reply=_clarify_text(lang),
            match_type="fallback",
            category="empty_input",
            language=lang,
            source_id="",
            confidence=0.0,
            answer_source="guard",
        )

    user_q = normalize(msg)

    # -----------------------------
    # Smalltalk
    # -----------------------------
    kind = detect_smalltalk(user_q, lang)
    if kind:
        return _json_response(
            reply=smalltalk_reply(kind, lang),
            match_type="smalltalk",
            category="",
            language=lang,
            source_id="",
            confidence=1.0,
            answer_source="smalltalk",
        )

    # -----------------------------
    # Location guard
    # -----------------------------
    if _is_outside_kurunegala(user_q):
        return _json_response(
            reply=_location_fallback_text(lang),
            match_type="fallback",
            category="out_of_scope_location",
            language=lang,
            source_id="",
            confidence=0.0,
            answer_source="guard",
        )

    # -----------------------------
    # Explicit non-domain guard
    # -----------------------------
    if _is_explicitly_non_domain(user_q):
        return _json_response(
            reply=_fallback_text(lang),
            match_type="fallback",
            category="out_of_domain",
            language=lang,
            source_id="",
            confidence=0.0,
            answer_source="guard",
        )

    best = None
    source = ""
    confidence = 0.0
    category = ""
    debug_hits = None

    # -----------------------------
    # Exact Match
    # -----------------------------
    exact_table = EXACT_TA if lang == "ta" else EXACT_SI
    if user_q in exact_table:
        best = exact_table[user_q]
        source = "exact"
        confidence = 1.0
    else:
        # -----------------------------
        # Semantic Search
        # -----------------------------
        try:
            hits = search(user_q, lang=lang, k=5)
        except Exception as e:
            logger.exception("Semantic search failed: %s", e)
            return _json_response(
                reply=_fallback_text(lang),
                match_type="error",
                category="system_error",
                language=lang,
                source_id="",
                confidence=0.0,
                answer_source="error",
            )

        if DEBUG:
            # Scores may be NumPy scalars (``top`` below is coerced with
            # float() too); round() would keep e.g. float32, which the JSON
            # encoder cannot serialize, so coerce first.
            debug_hits = [
                {
                    "id": h["id"],
                    "score": round(float(h["score"]), 4),
                    "category": h["item"].get("category", ""),
                    "matched_question": h["matched_question"],
                }
                for h in (hits or [])[:3]
            ]

        if not hits:
            return _json_response(
                reply=_fallback_text(lang),
                match_type="fallback",
                category="unknown",
                language=lang,
                source_id="",
                confidence=0.0,
                answer_source="semantic",
                debug_hits=debug_hits,
            )

        best_hit = hits[0]
        top = float(best_hit["score"])
        best = best_hit["item"]
        category = best.get("category", "general")
        confidence = top

        if top < FALLBACK_THRESHOLD:
            return _json_response(
                reply=_fallback_text(lang),
                match_type="fallback",
                category=category,
                language=lang,
                source_id=best_hit.get("id", ""),
                confidence=top,
                answer_source="semantic",
                debug_hits=debug_hits,
            )

        # Here top >= FALLBACK_THRESHOLD: a plausible match that is not yet
        # confident enough to answer.
        if top < CLARIFY_THRESHOLD:
            return _json_response(
                reply=_clarify_text(lang),
                match_type="clarification",
                category=category,
                language=lang,
                source_id=best_hit.get("id", ""),
                confidence=top,
                answer_source="semantic",
                debug_hits=debug_hits,
            )

        source = "semantic"

    # -----------------------------
    # Firestore-backed Answer Selection
    # -----------------------------
    answer_key = "answer_ta" if lang == "ta" else "answer_si"
    doc = None
    source_id = ""
    answer_source = "dataset"
    # Stays empty when `best` is not a dict, which leads to the fallback
    # reply below instead of an unbound-variable error.
    context_answer = ""

    if isinstance(best, dict):
        source_id = str(best.get("id", "")).strip()
        category = best.get("category", category)
        if source_id:
            try:
                doc = get_advice_by_id(source_id)
            except Exception as e:
                logger.exception("Firestore lookup failed for source_id=%s: %s", source_id, e)
                doc = None

        if doc and isinstance(doc, dict):
            # Firestore copy is preferred over the bundled dataset item.
            context_answer = doc.get(answer_key, "")
            category = doc.get("category", category)
            answer_source = "firestore"
        else:
            context_answer = best.get(answer_key, "")

    if not context_answer:
        return _json_response(
            reply=_fallback_text(lang),
            match_type="fallback",
            category=category or "unknown",
            language=lang,
            source_id=source_id,
            confidence=confidence,
            answer_source=answer_source,
            debug_hits=debug_hits,
        )

    # -----------------------------
    # Optional Qwen Output Layer
    # -----------------------------
    used_qwen = False
    final_reply = context_answer
    if USE_FINE_TUNED_MODEL and generate_grounded_answer is not None and source == "semantic":
        try:
            generated = generate_grounded_answer(user_q, context_answer, lang)
        except Exception as e:
            logger.exception("Qwen grounded generation failed: %s", e)
        else:
            if generated:
                final_reply = generated
                used_qwen = True
            else:
                # Never send an empty reply; keep the grounded answer instead.
                logger.warning("Qwen grounded generation returned an empty answer; using context answer")

    logger.info(
        {
            "message": msg,
            "normalized": user_q,
            "language": lang,
            "match_type": source,
            "source_id": source_id,
            "category": category,
            "confidence": round(confidence, 4),
            "answer_source": answer_source,
            "used_qwen": used_qwen,
        }
    )

    return _json_response(
        reply=final_reply,
        match_type=source,
        category=category,
        language=lang,
        source_id=source_id,
        confidence=confidence,
        answer_source=answer_source,
        debug_hits=debug_hits,
    )