Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import time | |
| from typing import List, Optional | |
| import faiss | |
| import numpy as np | |
| import pandas as pd | |
| import requests | |
| from datasets import load_dataset | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
# --- Runtime configuration -------------------------------------------------
# Every setting is read from the environment once, at import time, and falls
# back to the default shown here when the variable is unset.

# Model identifier placed in the "model" field of each OpenRouter request.
MODEL_NAME = os.environ.get("OPENROUTER_MODEL", "deepseek/deepseek-chat-v3-0324")

# OpenRouter bearer token; surrounding whitespace is removed and the value is
# an empty string when the variable is not set.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "").strip()

# JSONL file holding the knowledge chunks loaded by load_data().
KNOWLEDGE_URL = os.environ.get(
    "KNOWLEDGE_URL",
    "https://huggingface.co/datasets/sharshar1/arabic-hr-rag-dataset/resolve/main/knowledge_chunks.jsonl",
)

# JSONL file holding the FAQ question/answer pairs loaded by load_data().
FAQ_URL = os.environ.get(
    "FAQ_URL",
    "https://huggingface.co/datasets/sharshar1/arabic-hr-rag-dataset/resolve/main/faq_pairs.jsonl",
)
# FastAPI application instance for this service.
app = FastAPI(
    title="HR Genie RAG API",
    version="1.0.0",
)
# Request body accepted by the chat handler.
class ChatRequest(BaseModel):
    # The user's question; required, between 1 and 2000 characters.
    message: str = Field(..., min_length=1, max_length=2000)
    # Optional identifiers supplied by the client; nothing in this module
    # reads them.
    session_id: Optional[str] = None
    user_id: Optional[int] = None
# One retrieved record reported back to the client alongside the answer.
class SourceItem(BaseModel):
    # Display title of the record (topic or FAQ question).
    title: str
    # Value of the record's "source_file" column.
    source_file: str
    # Hybrid retrieval score, rounded by ask_assistant().
    score: float
# Response body returned by the chat handler.
class ChatResponse(BaseModel):
    # Final answer text produced by the model.
    answer: str
    # "rag" or "general_chat" as set by ask_assistant(); "unknown" if absent.
    mode: str
    # Records used as context for the answer (empty in general-chat mode).
    sources: List[SourceItem]
    # Time spent producing the answer, in milliseconds.
    latency_ms: int
# --- Retrieval state --------------------------------------------------------
# All four objects start as None and are populated by build_indexes();
# retrieve() refuses to run while any of them is still None.

# Combined knowledge + FAQ records, one row per searchable item.
search_df: Optional[pd.DataFrame] = None
# Sentence-embedding model used for both the corpus and incoming queries.
embedding_model: Optional[SentenceTransformer] = None
# Inner-product FAISS index over the corpus embeddings.
faiss_index: Optional[faiss.Index] = None
# BM25 lexical index over the tokenized corpus.
bm25: Optional[BM25Okapi] = None
def normalize_ar(text: str) -> str:
    """Return *text* in a canonical form suitable for Arabic matching.

    The value is coerced to ``str`` and then:

    * characters in U+064B-U+0652, U+0670 and U+0640 are removed,
    * the alef variants (أ إ آ ٱ) are folded to a bare alef,
    * alef maqsura becomes yeh and teh marbuta becomes heh,
    * whitespace runs collapse to one space and the ends are trimmed.
    """
    value = str(text)
    regex_rules = (
        (r"[\u064B-\u0652\u0670\u0640]", ""),
        (r"[أإآٱ]", "ا"),
    )
    for pattern, replacement in regex_rules:
        value = re.sub(pattern, replacement, value)
    value = value.replace("ى", "ي").replace("ة", "ه")
    return re.sub(r"\s+", " ", value).strip()
def simple_ar_tokenize(text: str) -> List[str]:
    """Split *text* into whitespace-delimited tokens after normalization.

    The text is passed through ``normalize_ar``; every character that is
    neither a word character nor whitespace is then replaced by a space
    before splitting.
    """
    without_punctuation = re.sub(r"[^\w\s]", " ", normalize_ar(text))
    return without_punctuation.split()
def detect_response_style(query: str) -> str:
    """Classify the language style the reply should use.

    Returns ``"english"`` when the query has no Arabic-block characters or
    has more ASCII letters than Arabic-block characters, ``"egyptian_ar"``
    when it contains one of the listed colloquial markers as a substring,
    and ``"fusha_ar"`` otherwise.
    """
    stripped = str(query).strip()
    arabic_count = sum(1 for _ in re.finditer(r"[\u0600-\u06FF]", stripped))
    latin_count = sum(1 for _ in re.finditer(r"[A-Za-z]", stripped))

    # No Arabic at all, or Latin letters dominate: reply in English.
    if arabic_count == 0 or latin_count > arabic_count:
        return "english"

    egyptian_markers = ("ازاي", "عاوز", "عايز", "ليه", "عامل", "ازيك", "دلوقتي", "ممكن")
    haystack = stripped.lower()
    for marker in egyptian_markers:
        if marker in haystack:
            return "egyptian_ar"
    return "fusha_ar"
def style_instruction(style: str) -> str:
    """Return the prompt sentence telling the model which language to use.

    Any style other than ``"english"`` or ``"egyptian_ar"`` receives the
    Modern Standard Arabic instruction.
    """
    instructions = {
        "english": "Answer in clear natural English.",
        "egyptian_ar": "أجب باللهجة المصرية بشكل طبيعي وواضح، وبأسلوب شات بوت ودود.",
    }
    return instructions.get(style, "أجب بالعربية الفصحى بشكل واضح ومهذب.")
def source_label_for_style(style: str) -> str:
    """Return the word used to introduce the source line for *style*."""
    if style == "english":
        return "Source"
    return "المصدر"
def sanitize_model_output(text: str) -> str:
    """Clean raw model output before it is shown to the user.

    Falsy input becomes an empty string. Runs of characters in
    U+4E00-U+9FFF are replaced by a space, three or more consecutive
    newlines shrink to two, runs of spaces/tabs shrink to one space, and
    the result is stripped.
    """
    cleaned = str(text or "")
    cleanup_steps = (
        (r"[\u4e00-\u9fff]+", " "),
        (r"\n{3,}", "\n\n"),
        (r"[ \t]{2,}", " "),
    )
    for pattern, replacement in cleanup_steps:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def openrouter_chat(messages: List[dict], temperature: float = 0.2, max_tokens: int = 300) -> str:
    """Send a chat-completion request to OpenRouter and return the cleaned reply.

    Args:
        messages: Chat messages, each a dict with ``role`` and ``content``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion length limit forwarded to the API.

    Returns:
        The first choice's message content after ``sanitize_model_output``.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is empty, or the response
            body does not contain a completion.
        requests.HTTPError: If the API answers with a non-2xx status.
    """
    if not OPENROUTER_API_KEY:
        raise RuntimeError("OPENROUTER_API_KEY is missing.")

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    response = requests.post(url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    data = response.json()

    # A 2xx body is not guaranteed to hold a completion (e.g. it may carry an
    # "error" object instead), so report that clearly instead of letting a
    # bare KeyError/IndexError escape.
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        error = data.get("error") if isinstance(data, dict) else None
        detail = error.get("message") if isinstance(error, dict) else error
        raise RuntimeError(
            f"OpenRouter returned no completion: {detail or 'unexpected response format'}"
        ) from exc
    return sanitize_model_output(content)
def load_data() -> pd.DataFrame:
    """Download both datasets and merge them into one searchable frame.

    Knowledge chunks and FAQ pairs are each given a ``search_text`` column
    (category, title/question and body joined with Arabic labels), plus
    ``answer_text``, ``display_title`` and ``record_type``. Only the shared
    columns are kept, and the two frames are stacked with a fresh index.
    """
    shared_columns = [
        "category_ar",
        "source_file",
        "record_type",
        "display_title",
        "answer_text",
        "search_text",
    ]

    knowledge_ds = load_dataset("json", data_files=KNOWLEDGE_URL)["train"]
    faq_ds = load_dataset("json", data_files=FAQ_URL)["train"]
    knowledge = pd.DataFrame(knowledge_ds)
    faqs = pd.DataFrame(faq_ds)

    # Knowledge chunks: the topic is the title and the content is the answer.
    knowledge_text = (
        "الفئة: "
        + knowledge["category_ar"].fillna("")
        + "\nالعنوان: "
        + knowledge["topic"].fillna("")
        + "\nالمحتوى: "
        + knowledge["content"].fillna("")
    )
    knowledge_rows = knowledge.assign(
        search_text=knowledge_text,
        answer_text=knowledge["content"],
        display_title=knowledge["topic"],
        record_type="knowledge",
    )

    # FAQ pairs: the question is the title and the answer is the answer.
    faq_text = (
        "الفئة: "
        + faqs["category_ar"].fillna("")
        + "\nالسؤال: "
        + faqs["question"].fillna("")
        + "\nالإجابة: "
        + faqs["answer"].fillna("")
    )
    faq_rows = faqs.assign(
        search_text=faq_text,
        answer_text=faqs["answer"],
        display_title=faqs["question"],
        record_type="faq",
    )

    frames = [knowledge_rows[shared_columns], faq_rows[shared_columns]]
    return pd.concat(frames, ignore_index=True)
def build_indexes() -> None:
    """Load the corpus and populate the module-level retrieval state.

    Fills ``search_df``, ``embedding_model``, ``faiss_index`` and ``bm25``:
    the corpus is embedded with the multilingual MiniLM model into an
    inner-product FAISS index, and tokenized for a BM25 index.
    """
    global search_df, embedding_model, faiss_index, bm25

    search_df = load_data()
    embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

    # Dense side: embeddings are L2-normalized at encode time.
    normalized_texts = [normalize_ar(entry) for entry in search_df["search_text"]]
    embeddings = embedding_model.encode(
        normalized_texts,
        batch_size=32,
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).astype("float32")
    embedding_dim = embeddings.shape[1]
    faiss_index = faiss.IndexFlatIP(embedding_dim)
    faiss_index.add(embeddings)

    # Lexical side: simple_ar_tokenize normalizes internally.
    tokenized_corpus = list(map(simple_ar_tokenize, search_df["search_text"].tolist()))
    bm25 = BM25Okapi(tokenized_corpus)
def retrieve(query: str, top_k: int = 3, candidate_k: int = 8) -> List[dict]:
    """Return the best-matching records for *query* using hybrid ranking.

    Candidates come from the FAISS vector index and the BM25 index. Each
    candidate is scored as ``0.65 / vector_rank + 0.35 / bm25_rank`` (a
    missing rank contributes 0) and the highest-scoring ``top_k`` are
    returned.

    Args:
        query: Raw user question.
        top_k: Number of records to return; values <= 0 yield an empty list.
        candidate_k: Candidates taken from each index before merging. It is
            raised to at least ``top_k`` and capped at the corpus size.

    Returns:
        A list of dicts with ``display_title``, ``answer_text``,
        ``source_file``, ``record_type``, ``category_ar`` and ``score``,
        sorted by descending score.

    Raises:
        RuntimeError: If ``build_indexes`` has not populated the indexes.
    """
    if search_df is None or embedding_model is None or faiss_index is None or bm25 is None:
        raise RuntimeError("Indexes are not initialized.")

    corpus_size = len(search_df)
    if top_k <= 0 or corpus_size == 0:
        return []
    # Never request more neighbours than exist: FAISS pads short result
    # lists with index -1, which iloc would resolve to the last row.
    pool_size = min(max(candidate_k, top_k), corpus_size)

    query_norm = normalize_ar(query)
    query_vec = embedding_model.encode(
        [query_norm],
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).astype("float32")
    vector_scores, vector_indices = faiss_index.search(query_vec, pool_size)
    bm25_scores = bm25.get_scores(simple_ar_tokenize(query_norm))
    bm25_top_idx = np.argsort(bm25_scores)[::-1][:pool_size]

    merged = {}
    for rank, (score, idx) in enumerate(zip(vector_scores[0], vector_indices[0]), start=1):
        idx = int(idx)
        if idx < 0:
            # Defensive: skip FAISS padding entries.
            continue
        merged[idx] = {
            "vector_rank": rank,
            "vector_score": float(score),
            "bm25_rank": None,
        }

    for rank, idx in enumerate(bm25_top_idx, start=1):
        idx = int(idx)
        if bm25_scores[idx] <= 0:
            # Scores are sorted in descending order, so every remaining
            # document also lacks positive lexical evidence; ranking them
            # would hand arbitrary rows a BM25 bonus.
            break
        merged.setdefault(
            idx,
            {
                "vector_rank": None,
                "vector_score": 0.0,
                "bm25_rank": None,
            },
        )
        merged[idx]["bm25_rank"] = rank

    items = []
    for idx, score_obj in merged.items():
        v_rank_score = 1.0 / score_obj["vector_rank"] if score_obj["vector_rank"] else 0.0
        b_rank_score = 1.0 / score_obj["bm25_rank"] if score_obj["bm25_rank"] else 0.0
        hybrid_score = 0.65 * v_rank_score + 0.35 * b_rank_score
        row = search_df.iloc[idx]
        items.append(
            {
                "display_title": str(row.get("display_title", "")),
                "answer_text": str(row.get("answer_text", "")),
                "source_file": str(row.get("source_file", "")),
                "record_type": str(row.get("record_type", "")),
                "category_ar": str(row.get("category_ar", "")),
                "score": float(hybrid_score),
            }
        )
    return sorted(items, key=lambda x: x["score"], reverse=True)[:top_k]
def build_prompt(query: str, chunks: List[dict]) -> str:
    """Assemble the single-message RAG prompt for *query*.

    The prompt holds the assistant role, the language instruction chosen
    from the query's detected style, the answering rules, one labelled
    block per retrieved chunk, and finally the user's question.
    """
    language_rule = style_instruction(detect_response_style(query))

    source_blocks = [
        f"[Source {number}]\n"
        f"Type: {chunk['record_type']}\n"
        f"Category: {chunk['category_ar']}\n"
        f"Title: {chunk['display_title']}\n"
        f"Content: {chunk['answer_text']}\n"
        f"File: {chunk['source_file']}"
        for number, chunk in enumerate(chunks, start=1)
    ]
    sources_section = "\n\n".join(source_blocks)

    prompt_parts = [
        "You are an HR smart assistant inside a company.\n\n",
        f"{language_rule}\n\n",
        "Use the provided sources when they are relevant and enough.\n",
        "If the sources are not enough, answer naturally from general knowledge.\n",
        "Do not mention retrieval or hidden reasoning.\n",
        "Keep answers short, clear, and practical.\n",
        "If you used sources, add one final source line in the same language style.\n\n",
        f"Sources:\n{sources_section}\n\n",
        f"User question:\n{query}\n\n",
        "Final answer:",
    ]
    return "".join(prompt_parts)
def ask_assistant(query: str) -> dict:
    """Answer *query*, using retrieved records as context when available.

    Returns a dict with ``answer``, ``mode`` (``"rag"`` or
    ``"general_chat"``) and ``sources`` (title, source file and rounded
    score for each record used).
    """
    retrieved = retrieve(query, top_k=3)
    style = detect_response_style(query)
    label = source_label_for_style(style)

    if retrieved:
        reply = openrouter_chat(
            [{"role": "user", "content": build_prompt(query, retrieved)}],
            temperature=0.15,
            max_tokens=300,
        )
        reply = sanitize_model_output(reply)

        # Append the top record's file when the model left out a source line.
        top_file = retrieved[0]["source_file"]
        if label not in reply and top_file:
            reply = f"{reply}\n{label}\n{top_file}"

        used_sources = []
        for record in retrieved:
            used_sources.append(
                {
                    "title": record["display_title"],
                    "source_file": record["source_file"],
                    "score": round(record["score"], 4),
                }
            )
        return {"answer": reply, "mode": "rag", "sources": used_sources}

    # Nothing retrieved: fall back to a plain conversational reply.
    system_prompt = (
        "You are a smart and friendly assistant. "
        f"{style_instruction(style)} "
        "Answer naturally and briefly."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    reply = openrouter_chat(conversation, temperature=0.4, max_tokens=260)
    return {"answer": reply, "mode": "general_chat", "sources": []}
# Registered as a startup hook: without it nothing in this module calls
# build_indexes(), and retrieve() would raise "Indexes are not initialized."
@app.on_event("startup")
def on_startup() -> None:
    """Build the retrieval indexes once when the application starts."""
    build_indexes()
@app.get("/")
def root() -> dict:
    """Return the service name, status and a map of its endpoints."""
    return {
        "name": "HR Genie RAG API",
        "status": "running",
        "endpoints": {
            "health": "/health",
            "chat": "/chat",
            "docs": "/docs",
        },
    }
@app.get("/health")
def health() -> dict:
    """Report liveness, the configured model and whether an API key is set.

    Only a boolean is exposed for the key; the key itself is never returned.
    """
    return {"status": "ok", "model": MODEL_NAME, "has_openrouter_key": bool(OPENROUTER_API_KEY)}
@app.post("/chat", response_model=ChatResponse)
def chat(payload: ChatRequest) -> ChatResponse:
    """Answer one user message and report the sources and latency.

    Args:
        payload: Validated request body carrying the user's message.

    Returns:
        The answer, the answering mode, the records used and the elapsed
        time in milliseconds.

    Raises:
        HTTPException: 422 when the message is blank after stripping, 500
            when producing the answer fails for any reason.
    """
    message = payload.message.strip()
    # min_length=1 still lets whitespace-only input through. Reject it here,
    # outside the try block, so it is not rewritten into a 500.
    if not message:
        raise HTTPException(status_code=422, detail="Message must not be empty.")

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock
    # adjustments.
    start = time.perf_counter()
    try:
        result = ask_assistant(message)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    latency_ms = int((time.perf_counter() - start) * 1000)

    sources = [SourceItem(**src) for src in result.get("sources", [])]
    return ChatResponse(
        answer=result.get("answer", ""),
        mode=result.get("mode", "unknown"),
        sources=sources,
        latency_ms=latency_ms,
    )