# rag-mode-api / app.py
# Page header residue: Rady10, commit c22088c ("Create app.py", verified)
import os
import json
import faiss
import torch
import numpy as np
from fastapi import FastAPI
from contextlib import asynccontextmanager
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ─────────────────────────────
# CONFIG
# ─────────────────────────────
# Hub repo id passed to AutoTokenizer / AutoModelForCausalLM at startup.
MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
# Hub dataset repo downloaded at startup; "agro.index" and "chunks.json"
# are read from it.
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
# SentenceTransformer model used to embed incoming queries.
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Device string; the causal LM in this service is loaded on CPU.
DEVICE = "cpu"
# Upper bound on newly generated tokens per reply (max_new_tokens).
MAX_TOKENS = 256
# Set for the whole process at import time.
# NOTE(review): presumably read by the Hugging Face tokenizers library to
# disable its internal parallelism — confirm this is still wanted.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# Module-level handles shared by the request handlers. They start as None
# and are assigned by lifespan() when the application starts up.
tokenizer = None    # tokenizer for MODEL_REPO
model = None        # causal language model for MODEL_REPO
embedder = None     # SentenceTransformer for EMBED_MODEL
faiss_index = None  # FAISS index read from "agro.index"
rag_chunks = None   # parsed contents of "chunks.json"
# ─────────────────────────────
# SYSTEM PROMPT
# ─────────────────────────────
# Base system message for every conversation. generate() appends a
# "Knowledge:" section to it when retrieval returns any text.
SYSTEM_PROMPT = """
You are an agriculture assistant.
Answer clearly and concisely in English or Arabic.
Focus on plant diseases, pests, irrigation, and farming advice.
"""
# ─────────────────────────────
# FASTAPI LIFESPAN (IMPORTANT)
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models and RAG artifacts at startup; release them at shutdown.

    Populates the module-level ``tokenizer``, ``model``, ``embedder``,
    ``faiss_index`` and ``rag_chunks`` before the application starts serving
    requests, and resets them to ``None`` once the application stops.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to FastAPI while the application is running.
    """
    global tokenizer, model, embedder, faiss_index, rag_chunks

    print("Loading RAG...")
    rag_dir = snapshot_download(
        repo_id=RAG_REPO,
        repo_type="dataset",
        local_dir="./rag",
    )
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    print("Loading embedder...")
    # Pin the embedder to the configured device instead of letting
    # SentenceTransformer auto-select one, so it matches the language model.
    embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

    print("Loading tokenizer...")
    # NOTE(review): trust_remote_code=True executes code shipped in the model
    # repo; confirm it is actually required for MODEL_REPO.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        trust_remote_code=True,
    )

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        device_map=DEVICE,  # honour the DEVICE constant rather than a literal
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()
    print("ALL LOADED")

    try:
        yield
    finally:
        # Drop the references so the weights and index can be reclaimed when
        # the app is stopped (e.g. between test clients or on reload).
        tokenizer = model = embedder = faiss_index = rag_chunks = None
# ASGI application; lifespan() runs around the app's startup and shutdown.
app = FastAPI(lifespan=lifespan)
# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Request body accepted by the POST /chat endpoint."""

    # The user's message; it is used both as the retrieval query and as the
    # user turn of the chat prompt.
    message: str
# ─────────────────────────────
# RAG
# ─────────────────────────────
def retrieve(query, k=3, min_score=0.3):
    """Return the knowledge chunks most relevant to *query* as one string.

    The query is embedded with the module-level ``embedder`` (normalized,
    float32) and searched against ``faiss_index``. The ``"text"`` field of
    each sufficiently high-scoring hit in ``rag_chunks`` is joined with a
    blank line between chunks.

    Args:
        query: User text to search for. ``None``, empty and whitespace-only
            values yield no context.
        k: Maximum number of chunks to fetch. Values below 1 yield no context.
        min_score: Hits must score strictly above this value to be kept.
            NOTE(review): assumes the index reports a similarity where larger
            is better — confirm against how "agro.index" was built.

    Returns:
        The selected chunk texts joined by ``"\\n\\n"``, or ``""`` when there
        is nothing to search for or no hit passes the threshold.
    """
    # Nothing meaningful to embed, or no results requested: skip the search.
    if not query or not query.strip() or k < 1:
        return ""

    emb = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, idxs = faiss_index.search(emb, k)

    # FAISS pads the result with -1 when fewer than k neighbours exist.
    return "\n\n".join(
        rag_chunks[int(idx)]["text"]
        for score, idx in zip(scores[0], idxs[0])
        if idx != -1 and score > min_score
    )
# ─────────────────────────────
# GENERATION
# ─────────────────────────────
def generate(text):
    """Generate the assistant's reply to *text*, grounded in retrieved context.

    Retrieves knowledge for the message, appends it (if any) to
    ``SYSTEM_PROMPT`` under a "Knowledge:" heading, renders the system and
    user messages with the tokenizer's chat template, and samples up to
    ``MAX_TOKENS`` new tokens from the model.

    Args:
        text: The user's message.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped).
    """
    context = retrieve(text)

    prompt = SYSTEM_PROMPT
    if context:
        prompt += "\n\nKnowledge:\n" + context

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            # temperature/top_p only take effect when sampling is on; request
            # it explicitly instead of relying on the checkpoint's defaults.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # generate() returns prompt + completion; keep only the new tokens.
    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
# ─────────────────────────────
# API ROUTES
# ─────────────────────────────
@app.get("/")
def home():
    """Report that the service is up."""
    payload = {"status": "running"}
    return payload
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer one chat message and return it under the "response" key."""
    return {"response": generate(req.message)}