import os

# Set the HF cache location before importing transformers; /tmp is writable
# in a Hugging Face Spaces container, where the default cache path may not be.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import torch

# Load the model and tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model.to("cuda" if torch.cuda.is_available() else "cpu")

app = FastAPI()

# Request format expected from the frontend
class GenerationRequest(BaseModel):
    system_message: str
    user_prompt: str
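# Example payload (illustrative values):
#   {"system_message": "You are a helpful assistant.",
#    "user_prompt": "Summarize today's logbook entry."}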
@app.post("/api/ai-generate")
async def generate_text(request: GenerationRequest):
    try:
        # Format messages to fit the chat template
        messages = [
            {"role": "system", "content": request.system_message},
            {"role": "user", "content": request.user_prompt},
        ]

        # Tokenize the input for generation
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # Greedy decoding: temperature and top_p only take effect when
        # do_sample=True, so they are dropped here, and the model's default
        # EOS token is kept so generation can stop before max_new_tokens.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=False,
            )

        # Decode only the new tokens (skip the prompt)
        generated_text = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True,
        )
        return JSONResponse({"generated_text": generated_text.strip()})
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
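

# Run with uvicorn when executed directly. Port 7860 is an assumption:
# it is the conventional port for Hugging Face Spaces; adjust if deploying
# elsewhere.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)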