Abhiru1 committed on
Commit
733f19b
·
verified ·
1 Parent(s): 449716d

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +473 -0
main.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import JSONResponse
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel, Field
8
+
9
+ from retrieval import search, EXACT_SI, EXACT_TA, normalize
10
+ from intents import detect_smalltalk, smalltalk_reply
11
+ from firestore_client import get_advice_by_id
12
+
13
+ # Optional Qwen output layer
14
+ try:
15
+ from finetuned_llm import generate_grounded_answer
16
+ except Exception:
17
+ generate_grounded_answer = None
18
+
19
+
20
+ app = FastAPI(title="Coco-Guide Backend", version="1.3")
21
+
22
+
23
+ # -----------------------------
24
+ # Logging
25
+ # -----------------------------
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger("coco_guide")
28
+
29
+
30
+ # -----------------------------
31
+ # CORS
32
+ # -----------------------------
33
+ DEBUG = os.getenv("DEBUG", "true").lower() == "true"
34
+
35
+ if DEBUG:
36
+ app.add_middleware(
37
+ CORSMiddleware,
38
+ allow_origins=["*"],
39
+ allow_credentials=False,
40
+ allow_methods=["*"],
41
+ allow_headers=["*"],
42
+ )
43
+ else:
44
+ app.add_middleware(
45
+ CORSMiddleware,
46
+ allow_origins=[
47
+ "https://your-frontend-domain.com"
48
+ ],
49
+ allow_credentials=True,
50
+ allow_methods=["*"],
51
+ allow_headers=["*"],
52
+ )
53
+
54
+
55
+ # -----------------------------
56
+ # Config
57
+ # -----------------------------
58
+ USE_FINE_TUNED_MODEL = os.getenv("USE_FINE_TUNED_MODEL", "false").lower() == "true"
59
+ FALLBACK_THRESHOLD = float(os.getenv("FALLBACK_THRESHOLD", "0.70"))
60
+ CLARIFY_THRESHOLD = float(os.getenv("CLARIFY_THRESHOLD", "0.80"))
61
+
62
+
63
+ # -----------------------------
64
+ # Request Schema
65
+ # -----------------------------
66
+ class ChatRequest(BaseModel):
67
+ message: str = Field(..., min_length=1, max_length=500)
68
+ language: str
69
+
70
+
71
+ # -----------------------------
72
+ # Messages
73
+ # -----------------------------
74
+ FALLBACK_SI = "කණගාටුයි, මට සහාය විය හැක්කේ පොල් වගාවට අදාළ කරුණු සඳහා පමණි. කරුණාකර ඔබේ ප්‍රශ්නය නැවත විමසන්න."
75
+ FALLBACK_TA = "மன்னிக்கவும், அந்தத் தகவல் தற்போது எங்களிடம் இல்லை. தயவுசெய்து மேலதிக ஆலோசனைகளுக்கு தென்னை பயிர்ச்செய்கை அதிகாரியைத் தொடர்பு கொள்ளவும்."
76
+
77
+ CLARIFY_SI = "කරුණාකර ඔබගේ ප්‍රශ්නය තව විස්තර කරන්න."
78
+ CLARIFY_TA = "தயவுசெய்து உங்கள் கேள்வியை மேலும் விளக்கவும்."
79
+
80
+ LOCATION_FALLBACK_SI = "කණගාටුයි, මෙම පද්ධතිය කුරුණෑගල දිස්ත්‍රික්කයේ පොල් වගාවට අදාළ උපදෙස් සඳහා පමණක් සීමා වී ඇත."
81
+ LOCATION_FALLBACK_TA = "மன்னிக்கவும், இந்த அமைப்பு குருநாகல் மாவட்டத்திலுள்ள தென்னைப் பயிர்ச்செய்கை தொடர்பான ஆலோசனைகளுக்கே மட்டுப்படுத்தப்பட்டுள்ளது."
82
+
83
+
84
+ # -----------------------------
85
+ # Domain / Location Guards
86
+ # -----------------------------
87
+ KURUNEGALA_TERMS = {
88
+ "kurunegala", "කුරුණෑගල", "குருநாகல்"
89
+ }
90
+
91
+ NON_KURUNEGALA_TERMS = {
92
+ "colombo", "කොළඹ", "கொழும்பு",
93
+ "gampaha", "ගම්පහ", "கம்பஹா",
94
+ "kandy", "මහනුවර", "கண்டி",
95
+ "galle", "ගාල්ල", "காலி",
96
+ "matara", "මාතර", "மாத்தறை",
97
+ "jaffna", "යාපනය", "யாழ்ப்பாணம்",
98
+ "batticaloa", "මඩකලපුව", "மட்டக்களப்பு",
99
+ "anuradhapura", "අනුරාධපුර", "அனுராதபுரம்",
100
+ "polonnaruwa", "පොළොන්නරුව", "பொலன்னறுவை",
101
+ "badulla", "බදුල්ල", "பதுளை",
102
+ "ratnapura", "රත්නපුර", "இரத்தினபுரி",
103
+ "kalutara", "කළුතර", "களுத்துறை",
104
+ "trincomalee", "ත්‍රිකුණාමලය", "திருகோணமலை",
105
+ "hambantota", "හම්බන්තොට", "அம்பாந்தோட்டை",
106
+ "ampara", "අම්පාර", "அம்பாறை",
107
+ "nuwara eliya", "නුවරඑළිය", "நுவரெலியா",
108
+ "vavuniya", "වව්නියා", "வவுனியா",
109
+ "kilinochchi", "කිලිනොච්චි", "கிளிநொச்சி",
110
+ "mannar", "මන්නාරම", "மன்னார்",
111
+ "puttalam", "පුත්තලම", "புத்தளம்",
112
+ "kegalle", "කෑගල්ල", "கேகாலை",
113
+ "monaragala", "මොනරාගල", "மொணராகலை",
114
+ }
115
+
116
+ NON_DOMAIN_TERMS = {
117
+ # English
118
+ "car", "bike", "phone", "laptop", "school", "exam", "movie", "music",
119
+ "politics", "election", "cricket", "football", "passport", "bank", "insurance",
120
+ "bus", "train", "airport", "visa", "hotel", "restaurant", "computer", "wifi",
121
+ "bitcoin", "tax", "loan", "job", "university", "doctor", "hospital",
122
+ "weather", "score", "match", "flight", "ticket", "salary", "mobile", "camera","oil","world",
123
+
124
+ # Sinhala
125
+ "කාර්", "බයික්", "ෆෝන්", "ලැප්ටොප්", "පාසල", "විභාග", "චිත්‍රපට",
126
+ "දේශපාලන", "ක්‍රිකට්", "පාස්පෝට්", "බැංකු", "රක්ෂණ",
127
+ "බස්", "දුම්රිය", "ගුවන් තොටුපළ", "විසා", "හෝටල", "ආපනශාලා",
128
+ "කම්පියුටර්", "වයිෆයි", "බදු", "ණය", "රැකියා", "විශ්වවිද්‍යාල",
129
+ "වෛද්‍ය", "රෝහල", "කාලගුණය", "ලකුණු", "ගුවන් ගමන්", "ටිකට්", "වැටුප්",
130
+ "ජංගම", "කැමරා","තෙල්","ලෝකය",
131
+
132
+ # Tamil
133
+ "கார்", "பைக்", "தொலைபேசி", "லாப்டாப்", "பாடசாலை", "தேர்வு",
134
+ "திரைப்படம்", "அரசியல்", "கிரிக்கெட்", "காப்பீடு", "வங்கி", "பாஸ்போர்ட்",
135
+ "பஸ்", "ரயில்", "விமான நிலையம்", "விசா", "ஹோட்டல்", "உணவகம்",
136
+ "கம்ப்யூட்டர்", "வைஃபை", "வரி", "கடன்", "வேலை", "பல்கலைக்கழகம்",
137
+ "மருத்துவர்", "மருத்துவமனை", "வானிலை", "மதிப்பெண்", "விமானம்", "டிக்கெட்",
138
+ "சம்பளம்", "மொபைல்", "கேமரா","எண்ணெய்","உலகம்"
139
+ }
140
+
141
+
142
+ # -----------------------------
143
+ # Helpers
144
+ # -----------------------------
145
+ def _fallback_text(lang: str) -> str:
146
+ return FALLBACK_TA if lang == "ta" else FALLBACK_SI
147
+
148
+
149
+ def _clarify_text(lang: str) -> str:
150
+ return CLARIFY_TA if lang == "ta" else CLARIFY_SI
151
+
152
+
153
+ def _location_fallback_text(lang: str) -> str:
154
+ return LOCATION_FALLBACK_TA if lang == "ta" else LOCATION_FALLBACK_SI
155
+
156
+
157
+ def _json_response(
158
+ reply: str,
159
+ match_type: str,
160
+ category: str,
161
+ language: str,
162
+ source_id: str = "",
163
+ confidence: float = 0.0,
164
+ answer_source: str = "",
165
+ debug_hits=None,
166
+ ):
167
+ payload = {
168
+ "reply": reply,
169
+ "match_type": match_type,
170
+ "category": category,
171
+ "language": language,
172
+ "source_id": source_id,
173
+ "confidence": round(float(confidence), 4),
174
+ "answer_source": answer_source,
175
+ }
176
+ if DEBUG and debug_hits is not None:
177
+ payload["debug_hits"] = debug_hits
178
+ return JSONResponse(content=payload)
179
+
180
+
181
+ def _contains_any_phrase(text: str, phrases: set[str]) -> bool:
182
+ t = normalize(text).lower()
183
+ phrases_sorted = sorted((p.lower() for p in phrases), key=len, reverse=True)
184
+ return any(p in t for p in phrases_sorted)
185
+
186
+
187
+ def _is_outside_kurunegala(text: str) -> bool:
188
+ t = normalize(text).lower()
189
+
190
+ if _contains_any_phrase(t, KURUNEGALA_TERMS):
191
+ return False
192
+
193
+ if _contains_any_phrase(t, NON_KURUNEGALA_TERMS):
194
+ return True
195
+
196
+ return False
197
+
198
+
199
+ def _is_explicitly_non_domain(text: str) -> bool:
200
+ return _contains_any_phrase(text, NON_DOMAIN_TERMS)
201
+
202
+
203
+ @app.on_event("startup")
204
+ def startup_checks():
205
+ if FALLBACK_THRESHOLD > CLARIFY_THRESHOLD:
206
+ raise ValueError("FALLBACK_THRESHOLD cannot be greater than CLARIFY_THRESHOLD")
207
+
208
+ logger.info(
209
+ {
210
+ "event": "startup",
211
+ "use_fine_tuned_model": USE_FINE_TUNED_MODEL,
212
+ "fallback_threshold": FALLBACK_THRESHOLD,
213
+ "clarify_threshold": CLARIFY_THRESHOLD,
214
+ "debug": DEBUG,
215
+ }
216
+ )
217
+
218
+
219
+ @app.get("/health")
220
+ def health():
221
+ return {
222
+ "status": "ok",
223
+ "use_fine_tuned_model": USE_FINE_TUNED_MODEL,
224
+ "fine_tuned_model_available": generate_grounded_answer is not None,
225
+ "fallback_threshold": FALLBACK_THRESHOLD,
226
+ "clarify_threshold": CLARIFY_THRESHOLD,
227
+ "debug": DEBUG,
228
+ }
229
+
230
+
231
+ if DEBUG:
232
+ @app.get("/test-firestore/{doc_id}")
233
+ def test_firestore(doc_id: str):
234
+ try:
235
+ doc = get_advice_by_id(doc_id)
236
+ if not doc:
237
+ return {"ok": False, "error": "Document not found", "doc_id": doc_id}
238
+ return {"ok": True, "doc_id": doc_id, "doc": doc}
239
+ except Exception as e:
240
+ return {"ok": False, "error": str(e), "doc_id": doc_id}
241
+
242
+
243
+ @app.post("/chat")
244
+ def chat(req: ChatRequest):
245
+ msg = (req.message or "").strip()
246
+ lang = (req.language or "").strip().lower()
247
+
248
+ if lang not in {"si", "ta"}:
249
+ raise HTTPException(status_code=400, detail="Invalid language. Use 'si' or 'ta'.")
250
+
251
+ if not msg:
252
+ return _json_response(
253
+ reply=_clarify_text(lang),
254
+ match_type="fallback",
255
+ category="empty_input",
256
+ language=lang,
257
+ source_id="",
258
+ confidence=0.0,
259
+ answer_source="guard",
260
+ )
261
+
262
+ user_q = normalize(msg)
263
+
264
+ # -----------------------------
265
+ # Smalltalk
266
+ # -----------------------------
267
+ kind = detect_smalltalk(user_q, lang)
268
+ if kind:
269
+ return _json_response(
270
+ reply=smalltalk_reply(kind, lang),
271
+ match_type="smalltalk",
272
+ category="",
273
+ language=lang,
274
+ source_id="",
275
+ confidence=1.0,
276
+ answer_source="smalltalk",
277
+ )
278
+
279
+ # -----------------------------
280
+ # Location guard
281
+ # -----------------------------
282
+ if _is_outside_kurunegala(user_q):
283
+ return _json_response(
284
+ reply=_location_fallback_text(lang),
285
+ match_type="fallback",
286
+ category="out_of_scope_location",
287
+ language=lang,
288
+ source_id="",
289
+ confidence=0.0,
290
+ answer_source="guard",
291
+ )
292
+
293
+ # -----------------------------
294
+ # Explicit non-domain guard
295
+ # -----------------------------
296
+ if _is_explicitly_non_domain(user_q):
297
+ return _json_response(
298
+ reply=_fallback_text(lang),
299
+ match_type="fallback",
300
+ category="out_of_domain",
301
+ language=lang,
302
+ source_id="",
303
+ confidence=0.0,
304
+ answer_source="guard",
305
+ )
306
+
307
+ best = None
308
+ source = ""
309
+ confidence = 0.0
310
+ category = ""
311
+ debug_hits = None
312
+
313
+ # -----------------------------
314
+ # Exact Match
315
+ # -----------------------------
316
+ if lang == "si" and user_q in EXACT_SI:
317
+ best = EXACT_SI[user_q]
318
+ source = "exact"
319
+ confidence = 1.0
320
+
321
+ elif lang == "ta" and user_q in EXACT_TA:
322
+ best = EXACT_TA[user_q]
323
+ source = "exact"
324
+ confidence = 1.0
325
+
326
+ else:
327
+ # -----------------------------
328
+ # Semantic Search
329
+ # -----------------------------
330
+ try:
331
+ hits = search(user_q, lang=lang, k=5)
332
+ except Exception as e:
333
+ logger.exception("Semantic search failed: %s", e)
334
+ return _json_response(
335
+ reply=_fallback_text(lang),
336
+ match_type="error",
337
+ category="system_error",
338
+ language=lang,
339
+ source_id="",
340
+ confidence=0.0,
341
+ answer_source="error",
342
+ )
343
+
344
+ if DEBUG:
345
+ debug_hits = [
346
+ {
347
+ "id": h["id"],
348
+ "score": round(h["score"], 4),
349
+ "category": h["item"].get("category", ""),
350
+ "matched_question": h["matched_question"],
351
+ }
352
+ for h in hits[:3]
353
+ ]
354
+
355
+ if not hits:
356
+ return _json_response(
357
+ reply=_fallback_text(lang),
358
+ match_type="fallback",
359
+ category="unknown",
360
+ language=lang,
361
+ source_id="",
362
+ confidence=0.0,
363
+ answer_source="semantic",
364
+ debug_hits=debug_hits,
365
+ )
366
+
367
+ best_hit = hits[0]
368
+ top = float(best_hit["score"])
369
+ best = best_hit["item"]
370
+ category = best.get("category", "general")
371
+ confidence = top
372
+
373
+ if top < FALLBACK_THRESHOLD:
374
+ return _json_response(
375
+ reply=_fallback_text(lang),
376
+ match_type="fallback",
377
+ category=category,
378
+ language=lang,
379
+ source_id=best_hit.get("id", ""),
380
+ confidence=top,
381
+ answer_source="semantic",
382
+ debug_hits=debug_hits,
383
+ )
384
+
385
+ if FALLBACK_THRESHOLD <= top < CLARIFY_THRESHOLD:
386
+ return _json_response(
387
+ reply=_clarify_text(lang),
388
+ match_type="clarification",
389
+ category=category,
390
+ language=lang,
391
+ source_id=best_hit.get("id", ""),
392
+ confidence=top,
393
+ answer_source="semantic",
394
+ debug_hits=debug_hits,
395
+ )
396
+
397
+ source = "semantic"
398
+
399
+ # -----------------------------
400
+ # Firestore-backed Answer Selection
401
+ # -----------------------------
402
+ doc = None
403
+ source_id = ""
404
+ answer_source = "dataset"
405
+
406
+ if isinstance(best, dict):
407
+ source_id = str(best.get("id", "")).strip()
408
+ category = best.get("category", category)
409
+
410
+ if source_id:
411
+ try:
412
+ doc = get_advice_by_id(source_id)
413
+ except Exception as e:
414
+ logger.exception("Firestore lookup failed for source_id=%s: %s", source_id, e)
415
+ doc = None
416
+
417
+ if doc and isinstance(doc, dict):
418
+ context_answer = doc.get("answer_ta", "") if lang == "ta" else doc.get("answer_si", "")
419
+ category = doc.get("category", category)
420
+ answer_source = "firestore"
421
+ else:
422
+ context_answer = best.get("answer_ta", "") if lang == "ta" else best.get("answer_si", "")
423
+
424
+ if not context_answer:
425
+ return _json_response(
426
+ reply=_fallback_text(lang),
427
+ match_type="fallback",
428
+ category=category or "unknown",
429
+ language=lang,
430
+ source_id=source_id,
431
+ confidence=confidence,
432
+ answer_source=answer_source,
433
+ debug_hits=debug_hits,
434
+ )
435
+
436
+ # -----------------------------
437
+ # Optional Qwen Output Layer
438
+ # -----------------------------
439
+ used_qwen = False
440
+ if USE_FINE_TUNED_MODEL and generate_grounded_answer is not None and source == "semantic":
441
+ try:
442
+ final_reply = generate_grounded_answer(user_q, context_answer, lang)
443
+ used_qwen = True
444
+ except Exception as e:
445
+ logger.exception("Qwen grounded generation failed: %s", e)
446
+ final_reply = context_answer
447
+ else:
448
+ final_reply = context_answer
449
+
450
+ logger.info(
451
+ {
452
+ "message": msg,
453
+ "normalized": user_q,
454
+ "language": lang,
455
+ "match_type": source,
456
+ "source_id": source_id,
457
+ "category": category,
458
+ "confidence": round(confidence, 4),
459
+ "answer_source": answer_source,
460
+ "used_qwen": used_qwen,
461
+ }
462
+ )
463
+
464
+ return _json_response(
465
+ reply=final_reply,
466
+ match_type=source,
467
+ category=category,
468
+ language=lang,
469
+ source_id=source_id,
470
+ confidence=confidence,
471
+ answer_source=answer_source,
472
+ debug_hits=debug_hits,
473
+ )