# app_vllm.py - Faster inference using vLLM
import os

import spaces
import gradio as gr
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None

def load_model():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    # Initialize vLLM with PEFT support
    llm = LLM(
        model=BASE_MODEL_ID,
        tokenizer=BASE_MODEL_ID,
        enable_lora=True,
        max_lora_rank=64,
        dtype="float16",
        gpu_memory_utilization=0.7,  # Reduced from 0.9 to avoid OOM on a T4 GPU
        trust_remote_code=True,
        download_dir=CACHE_DIR,  # Use persistent storage
    )
    print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    # Prepare the PEFT adapter. vLLM applies LoRA weights per request via LoRARequest
    # and expects a local adapter directory, so fetch the adapter repo from the Hub first.
    lora_path = snapshot_download(PEFT_MODEL_ID, token=HF_TOKEN, cache_dir=CACHE_DIR)
    lora_request = LoRARequest(
        lora_name="befm",
        lora_int_id=1,
        lora_path=lora_path,
    )
    print(f"[INFO] PEFT adapter prepared: {PEFT_MODEL_ID}")
    return llm, lora_request

# Lazy-load model and tokenizer
_llm = None
_lora_request = None
_tokenizer = None


def get_model_and_tokenizer():
    global _llm, _lora_request, _tokenizer
    if _llm is None:
        _llm, _lora_request = load_model()
        _tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_ID,
            token=HF_TOKEN,
            cache_dir=CACHE_DIR,  # Use persistent storage
        )
    return _llm, _lora_request, _tokenizer

def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
    llm, lora_request, tokenizer = get_model_and_tokenizer()

    # Apply the Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=int(max_new_tokens),  # Gradio sliders may deliver floats
    )

    # Generate with vLLM
    outputs = llm.generate(
        prompts=[prompt],
        sampling_params=sampling_params,
        lora_request=lora_request,
    )
    return outputs[0].outputs[0].text
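
# Illustrative standalone call (hypothetical message contents, not part of the app flow):
# generate_response expects an OpenAI-style message list and returns the generated text.
#   generate_response(
#       [{"role": "system", "content": "You are Be.FM."},
#        {"role": "user", "content": "Hello"}],
#       max_new_tokens=64,
#   )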

def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # Build the conversation in Llama 3.1 chat format
    messages = []

    # Add system prompt (use default if not provided)
    if not system_prompt:
        system_prompt = (
            "You are Be.FM, a helpful and knowledgeable AI assistant. "
            "Provide clear, accurate, and concise responses."
        )
    messages.append({"role": "system", "content": system_prompt})

    # History is already in dict format: [{"role": "user", "content": "..."}, ...]
    for msg in (history or []):
        messages.append(msg)

    if message:
        messages.append({"role": "user", "content": message})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return reply

demo = gr.ChatInterface(
    fn=chat_fn,
    type="messages",  # pass history as a list of {"role": ..., "content": ...} dicts
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (vLLM)",
    description=(
        "Chat interface using vLLM for optimized inference with "
        "Meta-Llama-3.1-8B-Instruct and the PEFT adapter befm/Be.FM-8B."
    ),
)

if __name__ == "__main__":
    demo.launch()
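
# Assumed dependencies for this Space (pin versions as needed in requirements.txt):
# vllm, transformers, gradio, huggingface_hub, spaces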