import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from peft import PeftModel
from transformers import T5ForConditionalGeneration, T5Tokenizer
| |
|
| | |
| | BASE_MODEL_NAME = "google/flan-t5-large" |
| | OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745" |
| | MAX_SOURCE_LENGTH = 1024 |
| | MAX_TARGET_LENGTH = 1024 |
| |
|
| | |
# Load tokenizer and base model once at module import so all requests share
# one in-memory copy.
# NOTE: `low_cpu_mem_usage` is a *model* loading option; it is not a
# T5Tokenizer option and was silently meaningless there, so it is dropped.
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = T5ForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",       # requires `accelerate`; places weights on available devices
    low_cpu_mem_usage=True,  # stream weights instead of materializing a full extra copy
)
# Attach the fine-tuned LoRA adapter weights from OUTPUT_DIR on top of the base model.
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
# Inference-only service: switch off dropout and other train-mode behavior.
model.eval()
def generate_text(prompt: str) -> str:
    """Generate a "humanized" rewrite of ``prompt`` with the LoRA-tuned model.

    The prompt is prefixed with the instruction string the adapter was
    trained on, tokenized, and decoded with nucleus sampling — so the
    output is intentionally non-deterministic.

    Args:
        prompt: Raw user text to rewrite; truncated to MAX_SOURCE_LENGTH tokens.

    Returns:
        The decoded model output with special tokens stripped.
    """
    input_text = "Humanize this text to be undetectable: " + prompt
    # Single-sequence batch: no padding needed. The previous
    # padding="max_length" padded every request to 1024 tokens, which only
    # wasted encoder compute without changing the (attention-masked) result.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
    )
    # Move input tensors to wherever device_map="auto" placed the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Defensive: guarantee no autograd state is built for this forward pass.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=MAX_TARGET_LENGTH,  # decoder-side token budget
            do_sample=True,                # nucleus sampling (stochastic output)
            top_p=0.95,
            temperature=0.9,
            num_return_sequences=1,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
app = FastAPI()


@app.post("/predict")
async def predict(request: Request):
    """Generate text for a JSON payload of the form ``{"prompt": "..."}``.

    Returns:
        200 with ``{"generated_text": ...}`` on success;
        400 if the body is not valid JSON, is not an object, or lacks a
        non-empty string "prompt".
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed / non-JSON bodies previously escaped as unhandled 500s.
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be valid JSON."},
        )

    prompt = data.get("prompt", "") if isinstance(data, dict) else ""
    if not isinstance(prompt, str) or not prompt:
        return JSONResponse(status_code=400, content={"error": "No prompt provided."})

    # Generation is synchronous and CPU/GPU-bound; run it in the threadpool
    # so it does not block the event loop for concurrent requests.
    output_text = await run_in_threadpool(generate_text, prompt)
    return {"generated_text": output_text}
| | if __name__ == "__main__": |
| | uvicorn.run(app, host="0.0.0.0", port=8000) |
| |
|