import json
import os
from contextlib import asynccontextmanager

import faiss
import numpy as np
import torch
from fastapi import FastAPI
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# ─────────────────────────────
# CONFIG
# ─────────────────────────────

MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

DEVICE = "cpu"
MAX_TOKENS = 256

# Retrieval defaults: how many neighbours to fetch from FAISS, and the
# minimum cosine similarity a chunk needs before it is put in the prompt.
RAG_TOP_K = 3
RAG_MIN_SCORE = 0.3

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# Populated once by `lifespan` at startup and read by the request handlers.

tokenizer = None
model = None
embedder = None
faiss_index = None
rag_chunks = None  # list of dicts with a "text" key, loaded from chunks.json

# ─────────────────────────────
# SYSTEM PROMPT
# ─────────────────────────────

SYSTEM_PROMPT = """
You are an agriculture assistant.
Answer clearly and concisely in English or Arabic.
Focus on plant diseases, pests, irrigation, and farming advice.
"""

# ─────────────────────────────
# FASTAPI LIFESPAN (IMPORTANT)
# ─────────────────────────────


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the RAG index, embedder, tokenizer and LLM once at startup.

    Everything is stored in module globals for the request handlers.
    FastAPI does not start accepting requests until this function reaches
    ``yield``, so handlers can assume every global is populated.
    """
    global tokenizer, model, embedder, faiss_index, rag_chunks

    print("Loading RAG...")
    rag_dir = snapshot_download(
        repo_id=RAG_REPO,
        repo_type="dataset",
        local_dir="./rag",
    )
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    print("Loading embedder...")
    embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

    # NOTE(review): Qwen2.5 is supported natively by recent transformers
    # releases; trust_remote_code=True executes Python shipped in the Hub
    # repo and can probably be dropped — confirm against the pinned version.
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        trust_remote_code=True,
    )

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        device_map=DEVICE,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()

    print("ALL LOADED")
    yield


app = FastAPI(lifespan=lifespan)

# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────


class ChatRequest(BaseModel):
    """Body of a ``POST /chat`` request."""

    message: str


# ─────────────────────────────
# RAG
# ─────────────────────────────


def retrieve(query: str, k: int = RAG_TOP_K, min_score: float = RAG_MIN_SCORE) -> str:
    """Return the knowledge-base chunks most similar to *query*.

    Args:
        query: User text to embed and search for.
        k: Number of nearest neighbours to request from FAISS.
        min_score: Minimum cosine similarity a chunk must exceed to be kept.

    Returns:
        The matching chunks' ``"text"`` fields joined by blank lines, or an
        empty string when the query is blank, ``k`` is not positive, or
        nothing clears the threshold.
    """
    if k <= 0 or not query or not query.strip():
        return ""

    emb = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, idxs = faiss_index.search(emb, k)

    # The threshold is a cosine similarity, which is what an inner-product
    # index returns for normalized vectors. An L2 index returns squared
    # distances instead (lower = closer), which would invert the filter, so
    # convert them. Assumes the stored vectors were normalized as well
    # (||a - b||² = 2 - 2·cos) — TODO confirm against the index build script.
    if faiss_index.metric_type == faiss.METRIC_L2:
        scores = 1.0 - scores / 2.0

    n_chunks = len(rag_chunks)
    results = [
        rag_chunks[int(idx)]["text"]
        for score, idx in zip(scores[0], idxs[0])
        # FAISS pads missing neighbours with -1; the upper bound guards
        # against an index / chunks.json size mismatch.
        if 0 <= idx < n_chunks and score > min_score
    ]
    return "\n\n".join(results)


# ─────────────────────────────
# GENERATION
# ─────────────────────────────


def generate(text: str) -> str:
    """Answer *text* with the LLM, grounded on retrieved knowledge chunks.

    Args:
        text: The user's message.

    Returns:
        The decoded completion, excluding the prompt tokens.
    """
    context = retrieve(text)

    system_prompt = SYSTEM_PROMPT.strip()
    if context:
        system_prompt += "\n\nKnowledge:\n" + context

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            # temperature/top_p only apply when sampling is enabled; state it
            # explicitly instead of relying on the model's generation config.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Slice off the prompt so only the newly generated answer is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)


# ─────────────────────────────
# API ROUTES
# ─────────────────────────────


@app.get("/")
def home():
    """Health check."""
    return {"status": "running"}


@app.post("/chat")
def chat(req: ChatRequest):
    """Generate an answer for ``req.message``.

    Declared with plain ``def`` so FastAPI runs the blocking, CPU-bound
    generation in its threadpool rather than on the event loop.
    """
    response = generate(req.message)
    return {"response": response}