Hugging Face Spaces status header (pasted from the Space page): the deployed
Space crashed at startup with a "Runtime error" — the message appears twice in
the captured log.
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
| import torch | |
| import os | |
app = FastAPI()

# Hugging Face auth token, supplied via the Space's secrets / environment.
HF_TOKEN = os.getenv("HF_TOKEN")
print("HF_TOKEN loaded:", HF_TOKEN is not None)

BASE_MODEL = "google/gemma-2b-it"
LORA_MODEL = "varshithkumar/gemma-finetuned-sql"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

print("Loading base model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # 4-bit weights to fit free-tier hardware
    bnb_4bit_compute_dtype=torch.float16,  # compute in float16
    bnb_4bit_use_double_quant=True,        # nested quantization, better accuracy
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",       # let accelerate dispatch layers across devices
    token=HF_TOKEN,          # `use_auth_token` is deprecated; use `token`
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,
    token=HF_TOKEN,
)

print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    token=HF_TOKEN,
    device_map="auto",       # load the adapter alongside the dispatched base model
)

# BUG FIX: the original called model.to(device) here. The base model is 4-bit
# quantized AND dispatched with accelerate's device_map="auto"; moving such a
# model raises a runtime error — the likely cause of this Space's crash.
# The model is already on the right device(s); just switch to inference mode.
model.eval()
print("Model loaded successfully!")
class InputData(BaseModel):
    """Request payload for the text-generation endpoint."""

    prompt: str  # input prompt to complete
    max_length: int = 256  # default max length if not provided (total token budget, prompt included)
# BUG FIX: the app created above had no registered route, so the Space served
# nothing; expose this handler as a POST endpoint.
@app.post("/generate")
def generate_text(data: InputData):
    """Generate a completion for `data.prompt` with the LoRA-adapted model.

    Args:
        data: InputData carrying the prompt and an optional `max_length`
            (caps prompt + generated tokens; defaults to 256).

    Returns:
        dict with a single "response" key holding the decoded text
        (special tokens stripped).
    """
    inputs = tokenizer(data.prompt, return_tensors="pt").to(device)
    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        # NOTE(review): max_length includes the prompt tokens; if long prompts
        # should still receive full completions, switch to max_new_tokens.
        outputs = model.generate(**inputs, max_length=data.max_length)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": text}