"""llm-project-3 / app.py: FastAPI chat API serving three LoRA personality adapters."""
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
app = FastAPI()
# Request model
class ChatRequest(BaseModel):
message: str
personality: str = "dwight"
# Load base model and tokenizer
base_model_name = "Qwen/Qwen2-0.5B-Instruct"
print("Loading base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
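# The tokenizer may ship without a pad token, so reuse EOS for padding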
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
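# float16 halves memory; device_map="auto" places layers on GPU when one is available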
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
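# All three personality adapters attach to the same base weights, so only one
# copy of the 0.5B model sits in memory; set_adapter() switches between them
# per request without reloading anything.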
print("Loading LoRA adapters into a single model...")
model = PeftModel.from_pretrained(
base_model,
"jerwng/dwight2",
adapter_name="dwight",
)
model.load_adapter("jerwng/michael2", adapter_name="michael")
model.load_adapter("jerwng/spongebob2", adapter_name="spongebob")
print("All models loaded successfully!")
# Personality metadata; keys must match the adapter names loaded above
PERSONALITIES = {
"dwight": {
"name": "Dwight Schrute",
"description": "Assistant Regional Manager with beet farm wisdom"
},
"michael": {
"name": "Michael Scott",
"description": "World's Best Boss with unique management style"
},
"spongebob": {
"name": "SpongeBob SquarePants",
"description": "Optimistic fry cook from Bikini Bottom"
}
}
def generate_response(message, personality):
"""
Generate response using fine-tuned LoRA models
"""
if personality not in PERSONALITIES:
personality = "dwight"
    # Activate the requested LoRA adapter on the shared model
    model.set_adapter(personality)
# Format the message as a chat
messages = [
{"role": "user", "content": message}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize and generate
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
    # Decode only the newly generated tokens: the prompt occupies the first
    # input_ids.shape[1] positions of the returned sequence, so slicing it off
    # is more robust than string-splitting on chat-template markers
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
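# Serve the static front-end; index.html is expected in the working directory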
@app.get("/")
async def index():
return FileResponse("index.html")
@app.post("/chat")
async def chat(request: ChatRequest):
    # Validate outside the try block so the 400 isn't swallowed and
    # re-raised as a 500 by the generic handler below
    if not request.message:
        raise HTTPException(status_code=400, detail="Message is required")
    # Normalize the personality here so the response metadata matches the
    # adapter that generate_response actually uses for unknown values
    personality = request.personality if request.personality in PERSONALITIES else "dwight"
    try:
        response = generate_response(request.message, personality)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "response": response,
        "personality": personality,
        "personality_name": PERSONALITIES[personality]["name"],
    }
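# Example request (assumes the server is reachable on localhost:7860; the
# message and personality values are illustrative):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "How do I grow beets?", "personality": "dwight"}'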
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)