"""
HuggingFace Space: Small LLM

Runs Phi-2 or a similar small model on ZeroGPU.
"""

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(
    title="Small LLM Space",
    description="Small LLM inference (Phi-2)",
)

MODEL_NAME = "microsoft/phi-2"

# Lazily-initialised globals, populated on the first call to load_model().
model = None
tokenizer = None


def load_model():
    """Lazily load the model and tokenizer on first use."""
    global model, tokenizer

    if model is None:
        print(f"Loading {MODEL_NAME}...")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

        # Phi-2's tokenizer defines no pad token; reuse the EOS token for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )

        print(f"Model loaded on {next(model.parameters()).device}")

    return model, tokenizer


class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.9


class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    model: str
    error: Optional[str] = None


@app.get("/")
async def root():
    return {
        "status": "running",
        "service": "llm",
        "model": MODEL_NAME,
        "gpu": torch.cuda.is_available(),
    }


async def generate(request: GenerateRequest):
    """Generate a text completion for the given prompt."""
    try:
        model, tokenizer = load_model()

        inputs = tokenizer(
            request.prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        )

        # Move the input tensors to the GPU when one is available.
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=request.temperature > 0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode the full sequence; the result contains the prompt followed by
        # the newly generated completion.
        generated_text = tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        )

        input_length = inputs["input_ids"].shape[1]
        output_length = outputs.shape[1]
        tokens_generated = output_length - input_length

        return GenerateResponse(
            text=generated_text,
            tokens_generated=tokens_generated,
            model=MODEL_NAME,
        )

    except Exception as e:
        return GenerateResponse(
            text="",
            tokens_generated=0,
            model=MODEL_NAME,
            error=str(e),
        )


# On ZeroGPU Spaces, wrap the handler with spaces.GPU *before* registering it
# as a route, so FastAPI stores a reference to the wrapped function. When the
# `spaces` package is unavailable (e.g. running locally), the plain handler is
# registered instead.
try:
    import spaces

    generate = spaces.GPU(generate)
except ImportError:
    pass

app.post("/api/generate", response_model=GenerateResponse)(generate)


def gradio_interface():
    """Build a simple Gradio UI on top of the generate handler."""
    import gradio as gr

    def generate_wrapper(prompt, max_tokens, temperature):
        from asyncio import run

        response = run(generate(GenerateRequest(
            prompt=prompt,
            max_tokens=int(max_tokens),  # Gradio sliders return floats
            temperature=temperature,
        )))
        return response.text or f"Error: {response.error}"

    iface = gr.Interface(
        fn=generate_wrapper,
        inputs=[
            gr.Textbox(lines=5, label="Prompt"),
            gr.Slider(50, 500, value=200, step=1, label="Max Tokens"),
            gr.Slider(0.0, 1.5, value=0.7, label="Temperature"),
        ],
        outputs=gr.Textbox(lines=10, label="Generated Text"),
        title="Small LLM (Phi-2)",
        description="Generate text using the Phi-2 model",
    )

    return iface
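

# Sketch (an assumption, not part of the original flow): the Gradio UI built by
# gradio_interface() could be mounted onto the same FastAPI app so the REST API
# and the web UI are served from one process, e.g.:
#
#   import gradio as gr
#   app = gr.mount_gradio_app(app, gradio_interface(), path="/ui")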


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)