# AdmissionOfficer/src/hf_client.py
# Author: Rwan Ashraf
# Change: Improving Arabic accuracy / adding Arabic fallback
# Commit: 6c46177
"""HuggingFace Client for embeddings and text generation"""
import os
import re
from huggingface_hub import InferenceClient
# Optional API token; when unset, the public (rate-limited) endpoint is used.
HF_TOKEN = os.getenv("HF_TOKEN")
# English fallback returned when no grounded answer can be produced.
UNKNOWN_FALLBACK = "Thank you for your inquiry. Unfortunately, I am unable to provide an answer to your question at this time. For accurate and up-to-date information, please contact the admissions office at admission@nu.edu.eg"
# Arabic counterpart of the fallback, selected for Arabic questions.
ARABIC_UNKNOWN_FALLBACK = "شكرا لسؤالك. للاسف لا استطيع تقديم اجابة على هذا السؤال حاليا. للمعلومات الدقيقة والمحدثة يرجى التواصل مع مكتب القبول عبر admission@nu.edu.eg"
# =========================
# MODELS
# =========================
# Multilingual sentence-embedding model; embed_texts() below assumes
# 384-dimensional vectors for its zero-vector fallbacks.
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Generation model, overridable via the GEN_MODEL environment variable.
GEN_MODEL = os.getenv("GEN_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Import-time status prints (this module logs via print, not logging).
print(f"[OK] Embedding model: {EMBED_MODEL}")
print(f"[OK] Generation model: {GEN_MODEL}")
print(f"[OK] HF_TOKEN: {'Set' if HF_TOKEN else 'Not set (using public API)'}")
# Create clients
# NOTE(review): both clients are constructed identically; the two names only
# keep embedding and generation call sites visually distinct.
embed_client = InferenceClient(token=HF_TOKEN)
gen_client = InferenceClient(token=HF_TOKEN)
def normalize_text(text: str) -> list[str]:
    """Normalize text into lowercase tokens for both English and Arabic.

    Arabic letters with multiple orthographic variants (alef forms, final ya,
    hamza carriers, ta marbuta) are folded to one canonical form so lexical
    overlap checks are not defeated by spelling variation.

    Args:
        text: Input string; ``None`` is treated as empty.

    Returns:
        List of lowercase word tokens (Latin word characters and Arabic
        letters); punctuation and whitespace are dropped.
    """
    # Single-pass character folding replaces the previous chain of
    # per-character re.sub calls (same mapping: إ/أ/آ -> ا, ى -> ي,
    # ؤ/ئ -> ء, ة -> ه; the redundant ا -> ا identity is dropped).
    fold_table = str.maketrans("إأآىؤئة", "ااايءءه")
    normalized = (text or "").lower().translate(fold_table)
    return re.findall(r"[\w\u0600-\u06FF]+", normalized, flags=re.UNICODE)
def contains_arabic(text: str) -> bool:
    """Return True if *text* contains at least one character from the
    Arabic Unicode block (U+0600 - U+06FF). ``None`` counts as empty."""
    for ch in text or "":
        if "\u0600" <= ch <= "\u06ff":
            return True
    return False
def get_unknown_fallback(question: str = "") -> str:
    """Select the "contact admissions" fallback message matching the
    language of *question* (Arabic text gets the Arabic message)."""
    if contains_arabic(question):
        return ARABIC_UNKNOWN_FALLBACK
    return UNKNOWN_FALLBACK
def enforce_unknown_if_ungrounded(answer: str, context: str, question: str = "") -> str:
    """
    Return a safe fallback when the generated answer appears ungrounded.

    This is a lightweight lexical check to reduce hallucinations: the answer
    must share enough normalized tokens with the retrieved context, and must
    not introduce links the context does not contain.

    Args:
        answer: Model-generated answer (may be None or empty).
        context: Retrieved context the answer should be grounded in.
        question: Original user question; selects the fallback language.

    Returns:
        The original answer when it looks grounded, otherwise the
        language-appropriate fallback message.
    """
    fallback = get_unknown_fallback(question)
    answer = (answer or "").strip()
    if not answer:
        return fallback
    # Known "I don't know" phrasings. The fallback texts are compared via
    # the module constants (not hardcoded copies) so they cannot drift if
    # the constants change; the Arabic fallback is covered as well.
    known_unknowns = {
        "i don't know",
        "i dont know",
        "unknown",
        UNKNOWN_FALLBACK.lower(),
        ARABIC_UNKNOWN_FALLBACK.lower(),
    }
    if answer.lower() in known_unknowns:
        return fallback
    # Tokens of length <= 2 are too common to be a grounding signal.
    answer_tokens = [t for t in normalize_text(answer) if len(t) > 2]
    context_tokens = set(t for t in normalize_text(context) if len(t) > 2)
    if not answer_tokens or not context_tokens:
        return fallback
    # If the answer introduces links that are not in the retrieved context,
    # mark it as ungrounded.
    link_markers = ("http://", "https://", "www.")
    safe_context = context or ""
    if any(m in answer for m in link_markers) and not any(
        m in safe_context for m in link_markers
    ):
        return fallback
    # Arabic answers may be valid translations of English context.
    # In that cross-language case, lexical overlap is not a reliable signal.
    if contains_arabic(answer) and not contains_arabic(context):
        return answer
    overlap = sum(1 for t in answer_tokens if t in context_tokens)
    overlap_ratio = overlap / len(answer_tokens)
    if overlap_ratio < 0.2:
        return fallback
    return answer
# =========================
# Embeddings
# =========================
def embed_texts(texts):
    """Generate embeddings for a list of texts.

    Args:
        texts: A single string or an iterable of strings.

    Returns:
        A list with one embedding (list of floats) per input text. On any
        error or unrecognizable response shape, zero vectors of length 384
        are returned instead (384 is presumably the configured MiniLM
        model's embedding dimension — confirm if EMBED_MODEL changes).
    """
    try:
        input_texts = [texts] if isinstance(texts, str) else list(texts)
        # A single input is sent as a bare string rather than a one-element
        # list; batches are sent as-is.
        result = embed_client.feature_extraction(
            text=input_texts if len(input_texts) > 1 else input_texts[0],
            model=EMBED_MODEL,
        )
        # numpy-like results expose .tolist(); convert to plain Python lists.
        if hasattr(result, "tolist"):
            result = result.tolist()
        # Normalize response shape to always be List[List[float]].
        # Some backends return a single vector even for batched input.
        if not isinstance(result, list):
            # NOTE(review): a non-list result is duplicated per input as-is;
            # callers receive whatever object the backend sent.
            return [result for _ in input_texts]
        if not result:
            return [[0.0] * 384 for _ in input_texts]
        # Single-vector response: [0.1, 0.2, ...] — duplicate it per input.
        if isinstance(result[0], (int, float)):
            vector = [float(v) for v in result]
            return [vector for _ in input_texts]
        # Matrix response: [[...], [...], ...]
        if isinstance(result[0], list) and result and result[0] and isinstance(result[0][0], (int, float)):
            matrix = [[float(v) for v in row] for row in result]
            if len(input_texts) == 1:
                return [matrix[0] if len(matrix) >= 1 else [0.0] * 384]
            if len(result) == len(input_texts):
                return matrix
            # Backend collapsed a batch into one row: duplicate it per input.
            if len(result) == 1:
                return [matrix[0] for _ in input_texts]
        # Nested/irregular fallback: coerce each embedding row to list of floats.
        coerced = []
        for row in result:
            if hasattr(row, "tolist"):
                row = row.tolist()
            if isinstance(row, list):
                coerced.append([float(v) for v in row])
        if len(coerced) == len(input_texts):
            return coerced
        if len(coerced) == 1:
            return [coerced[0] for _ in input_texts]
        # Shape could not be reconciled with the inputs: zero vectors.
        return [[0.0] * 384 for _ in input_texts]
    except Exception as e:
        # Best-effort by design: embedding failures degrade to zero vectors
        # rather than crashing the caller.
        print(f"[WARN] Embedding error: {e}")
        input_texts = [texts] if isinstance(texts, str) else list(texts)
        return [[0.0] * 384 for _ in input_texts]
# =========================
# Generation
# =========================
def generate_answer(system_prompt, user_prompt, conversation_history=None):
    """
    Generate answer using HuggingFace Inference API.

    Tries ``chat_completion`` first (works with most instruct models), then
    falls back to raw ``text_generation`` with an [INST]-style prompt. If
    both fail, returns the language-appropriate fallback message.

    Args:
        system_prompt: System instruction for the model.
        user_prompt: Current user question (plus any retrieved context).
        conversation_history: Optional list of (question, answer) pairs;
            only the last two exchanges are sent to the model.

    Returns:
        Generated answer text, or a fallback message when generation fails.
    """
    # Try chat_completion first (works with most instruct models)
    try:
        messages = []
        # Add system message
        messages.append({"role": "system", "content": system_prompt})
        # Add conversation history (last 2 exchanges)
        if conversation_history:
            for q, a in conversation_history[-2:]:
                messages.append({"role": "user", "content": str(q)})
                messages.append({"role": "assistant", "content": str(a)})
        # Add current question
        messages.append({"role": "user", "content": user_prompt})
        print(f"[INFO] Generating with {GEN_MODEL} (chat_completion)...")
        print(f"    Messages: {len(messages)}")
        response = gen_client.chat_completion(
            messages=messages,
            model=GEN_MODEL,
            max_tokens=500,
            temperature=0.5,
            top_p=0.9
        )
        # Guard against a None content field before stripping.
        answer = (response.choices[0].message.content or "").strip()
        # Very short responses are treated as failures so the fallback runs.
        if answer and len(answer) > 10:
            print(f"[OK] Generated: {len(answer)} chars")
            return answer
        raise Exception("Empty response from chat_completion")
    except Exception as e:
        print(f"[WARN] chat_completion failed: {e}")
        # Fallback to text_generation
        try:
            # Build prompt in Mistral/Llama format
            prompt = f"<s>[INST] {system_prompt}\n\n"
            if conversation_history:
                for q, a in conversation_history[-2:]:
                    prompt += f"User: {q}\nAssistant: {a}\n\n"
            prompt += f"{user_prompt} [/INST]"
            print("[WARN] Trying text_generation fallback...")
            response = gen_client.text_generation(
                prompt,
                model=GEN_MODEL,
                max_new_tokens=400,
                temperature=0.5,
                do_sample=True,
                return_full_text=False
            )
            if response and isinstance(response, str):
                answer = response.strip()
                # Clean up artifacts
                if answer.startswith("[/INST]"):
                    answer = answer.replace("[/INST]", "").strip()
                print(f"[OK] text_generation succeeded: {len(answer)} chars")
                return answer
            raise Exception("Empty response from text_generation")
        except Exception as e2:
            print(f"[ERR] Both methods failed:")
            print(f"   chat_completion: {e}")
            print(f"   text_generation: {e2}")
            # Language-aware fallback: an Arabic question gets the Arabic
            # message (previously this always returned the English one).
            # NOTE(review): user_prompt may embed English context around an
            # Arabic question; contains_arabic still detects it.
            return get_unknown_fallback(user_prompt)