from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

app = FastAPI()


# Request model
class ChatRequest(BaseModel):
    message: str
    personality: str = "dwight"


# Load base model and tokenizer
base_model_name = "Qwen/Qwen2-0.5B-Instruct"

print("Loading base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

print("Loading LoRA adapters into a single model...")
model = PeftModel.from_pretrained(
    base_model,
    "jerwng/dwight2",
    adapter_name="dwight",
)
model.load_adapter("jerwng/michael2", adapter_name="michael")
model.load_adapter("jerwng/spongebob2", adapter_name="spongebob")
print("All adapters loaded successfully!")

# Model personalities
PERSONALITIES = {
    "dwight": {
        "name": "Dwight Schrute",
        "description": "Assistant Regional Manager with beet farm wisdom",
    },
    "michael": {
        "name": "Michael Scott",
        "description": "World's Best Boss with unique management style",
    },
    "spongebob": {
        "name": "SpongeBob SquarePants",
        "description": "Optimistic fry cook from Bikini Bottom",
    },
}


def generate_response(message, personality):
    """Generate a response using the fine-tuned LoRA adapters."""
    if personality not in PERSONALITIES:
        personality = "dwight"

    # Activate the requested LoRA adapter on the shared base model
    model.set_adapter(personality)

    # Format the message as a chat
    messages = [
        {"role": "user", "content": message}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize and generate
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the full sequence (prompt + completion)
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the assistant's reply. With skip_special_tokens=True the
    # ChatML markers are usually stripped already, so the plain
    # "assistant\n" split is the path normally taken.
    if "<|im_start|>assistant" in full_response:
        response = full_response.split("<|im_start|>assistant")[-1].strip()
        response = response.replace("<|im_end|>", "").strip()
    else:
        response = full_response.split("assistant\n")[-1].strip()

    return response


@app.get("/")
async def index():
    return FileResponse("index.html")


@app.post("/chat")
async def chat(request: ChatRequest):
    # Validate outside the try block so the 400 isn't swallowed and
    # re-raised as a 500 by the generic handler below
    if not request.message:
        raise HTTPException(status_code=400, detail="Message is required")

    # Fall back to the default personality for unknown values, matching the
    # fallback inside generate_response (avoids a KeyError on the lookup below)
    personality = request.personality if request.personality in PERSONALITIES else "dwight"

    try:
        # Generate response
        response = generate_response(request.message, personality)
        return {
            "response": response,
            "personality": personality,
            "personality_name": PERSONALITIES[personality]["name"],
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
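
# --- Example client call ---
# A minimal sketch of hitting the /chat endpoint; assumes the server is
# running locally on port 7860 and that the `requests` library is installed.
# The response keys match what the chat() handler above returns.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "What's the best way to grow beets?", "personality": "dwight"},
#   )
#   resp.raise_for_status()
#   data = resp.json()
#   print(f"{data['personality_name']}: {data['response']}")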