Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import faiss | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI | |
| from contextlib import asynccontextmanager | |
| from pydantic import BaseModel | |
| from huggingface_hub import snapshot_download | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ---------------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------------
# Hugging Face Hub ids: chat LLM, prebuilt RAG dataset (FAISS index + chunk
# texts), and the multilingual sentence-embedding model used for retrieval.
MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

DEVICE = "cpu"     # inference device
MAX_TOKENS = 256   # cap on newly generated tokens per reply

# Turn off the HF `tokenizers` library's internal parallelism.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# ---------------------------------------------------------------------------
# GLOBALS
# ---------------------------------------------------------------------------
# Filled in once by `lifespan` at application startup; None until then.
tokenizer = model = embedder = faiss_index = rag_chunks = None
# ---------------------------------------------------------------------------
# SYSTEM PROMPT
# ---------------------------------------------------------------------------
# Base system message; `generate` may append retrieved knowledge after it.
SYSTEM_PROMPT = (
    "\n"
    "You are an agriculture assistant.\n"
    "Answer clearly and concisely in English or Arabic.\n"
    "Focus on plant diseases, pests, irrigation, and farming advice.\n"
)
# ---------------------------------------------------------------------------
# FASTAPI LIFESPAN (IMPORTANT)
# ---------------------------------------------------------------------------
# FastAPI/Starlette expect `lifespan` to be an async context-manager factory.
# A bare async generator is only accepted through a deprecated fallback, and
# the `asynccontextmanager` import exists precisely for this decorator.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model and index into the module globals before serving.

    Code before ``yield`` runs once at application startup. The application
    then serves requests while suspended at ``yield``.

    Steps:
      1. Download the RAG dataset snapshot into ``./rag``.
      2. Read the FAISS index (``agro.index``) and the chunk list
         (``chunks.json``) from that snapshot.
      3. Load the sentence embedder.
      4. Load the chat tokenizer and the causal LM (float32, eval mode) on
         ``DEVICE``.
    """
    global tokenizer, model, embedder, faiss_index, rag_chunks

    print("Loading RAG...")
    rag_dir = snapshot_download(
        repo_id=RAG_REPO,
        repo_type="dataset",
        local_dir="./rag",
    )
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    print("Loading embedder...")
    # Pin to DEVICE so retrieval runs where the config says, like the LLM.
    embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

    print("Loading tokenizer...")
    # NOTE(review): recent transformers releases support Qwen2.5 natively, so
    # trust_remote_code may be unnecessary. Confirm before removing it.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        trust_remote_code=True,
    )

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        device_map=DEVICE,  # was hard-coded "cpu"; DEVICE was otherwise unused
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()  # inference only

    print("ALL LOADED")
    yield


app = FastAPI(lifespan=lifespan)
# ---------------------------------------------------------------------------
# REQUEST MODEL
# ---------------------------------------------------------------------------
# JSON body accepted by the chat handler, e.g. {"message": "..."}.
class ChatRequest(BaseModel):
    message: str  # the user's question; handed unchanged to `generate`
# ---------------------------------------------------------------------------
# RAG
# ---------------------------------------------------------------------------
def retrieve(query, k=3, min_score=0.3):
    """Return up to ``k`` knowledge-base chunks relevant to ``query``.

    The query is embedded with the global ``embedder`` (normalised
    embeddings) and searched in the global ``faiss_index``. Hits scoring
    above ``min_score`` are mapped to ``rag_chunks[idx]["text"]`` and joined
    with blank lines.

    Args:
        query: User text to search for. ``None``, empty and whitespace-only
            queries return ``""`` without touching the index.
        k: Number of nearest neighbours to request. Values <= 0 return ``""``.
        min_score: Keep only hits whose score is strictly greater than this.
            NOTE(review): treating a higher score as better assumes an
            inner-product index (cosine similarity on normalised vectors).
            If ``agro.index`` is an L2 index this comparison is inverted;
            confirm against how the index was built.

    Returns:
        The matching chunk texts separated by blank lines, or ``""`` when
        nothing qualifies.
    """
    # Guard before touching the embedder: blank queries would otherwise be
    # embedded and match arbitrary chunks.
    if not query or not query.strip() or k <= 0:
        return ""

    emb = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, idxs = faiss_index.search(emb, k)

    # FAISS reports idx == -1 for empty slots when fewer than k vectors exist.
    return "\n\n".join(
        rag_chunks[idx]["text"]
        for score, idx in zip(scores[0], idxs[0])
        if idx != -1 and score > min_score
    )
# ---------------------------------------------------------------------------
# GENERATION
# ---------------------------------------------------------------------------
def generate(text, max_new_tokens=None):
    """Answer ``text`` with the chat LLM, grounded on retrieved RAG context.

    Retrieved chunks (if any) are appended to ``SYSTEM_PROMPT`` under a
    "Knowledge:" heading. The system and user messages are rendered with
    the tokenizer's chat template, and a reply is sampled from the global
    ``model``.

    Args:
        text: The user's message.
        max_new_tokens: Upper bound on generated tokens. ``None`` (default)
            means ``MAX_TOKENS``.

    Returns:
        Only the newly generated text (the prompt tokens are sliced off),
        decoded with special tokens removed.
    """
    if max_new_tokens is None:
        max_new_tokens = MAX_TOKENS

    context = retrieve(text)
    system_content = SYSTEM_PROMPT
    if context:
        system_content += "\n\nKnowledge:\n" + context

    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": text},
    ]
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # end the prompt with the assistant turn
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # temperature/top_p only apply in sampling mode. Without an
            # explicit do_sample=True, whether they take effect depends on
            # the model's generation_config.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# API ROUTES
# ---------------------------------------------------------------------------
# Without a route decorator this handler was never registered on `app`.
# NOTE(review): path "/" is inferred from the handler name; confirm it
# matches what clients / health checks call.
@app.get("/")
def home():
    """Liveness check: always reports that the service is running."""
    return {"status": "running"}
# Without a route decorator this endpoint was never registered on `app`.
# NOTE(review): path "/chat" is inferred from the handler name; confirm it
# matches what clients call.
@app.post("/chat")
def chat(req: ChatRequest):
    """Generate an answer for ``req.message`` and wrap it as ``{"response": ...}``."""
    # Plain `def` on purpose: FastAPI runs sync path functions in a
    # threadpool, so the blocking `generate` call does not stall the
    # event loop.
    response = generate(req.message)
    return {"response": response}