# FastAPI chat server for a Hugging Face Space: serves Qwen2-0.5B-Instruct
# with three persona LoRA adapters (Dwight / Michael / SpongeBob).
# (Page-scrape residue "Spaces: / Sleeping / Sleeping" removed.)
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# Single FastAPI application instance; route handlers are defined below.
app = FastAPI()
# Request model
class ChatRequest(BaseModel):
    """JSON body for the chat endpoint."""
    # The user's chat message (required; rejected with 400 when empty).
    message: str
    # Persona adapter key; expected to be one of the PERSONALITIES keys.
    personality: str = "dwight"
# Load base model and tokenizer
base_model_name = "Qwen/Qwen2-0.5B-Instruct"
print("Loading base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
# Reuse EOS as the pad token so padded batches work during generation
# (presumably Qwen2 ships without a dedicated pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,  # half precision to reduce memory footprint
    device_map="auto",  # let accelerate place weights on available devices
    trust_remote_code=True,
)
# Attach all three LoRA adapters to ONE shared base model; the active
# adapter is switched per-request via model.set_adapter(...).
print("Loading LoRA adapters into a single model...")
model = PeftModel.from_pretrained(
    base_model,
    "jerwng/dwight2",
    adapter_name="dwight",
)
model.load_adapter("jerwng/michael2", adapter_name="michael")
model.load_adapter("jerwng/spongebob2", adapter_name="spongebob")
print("All models loaded successfully!")
# Model personalities
# Keys must match the adapter_name values loaded into `model`; used both to
# validate incoming personality requests and for display names in responses.
PERSONALITIES = {
    "dwight": {
        "name": "Dwight Schrute",
        "description": "Assistant Regional Manager with beet farm wisdom"
    },
    "michael": {
        "name": "Michael Scott",
        "description": "World's Best Boss with unique management style"
    },
    "spongebob": {
        "name": "SpongeBob SquarePants",
        "description": "Optimistic fry cook from Bikini Bottom"
    }
}
def generate_response(message, personality):
    """Run one chat turn through the LoRA adapter for *personality*.

    Unknown personality keys fall back to "dwight". Returns the
    assistant's reply text with chat-template markers stripped.
    """
    # Guard: unknown personas use the default adapter.
    if personality not in PERSONALITIES:
        personality = "dwight"
    model.set_adapter(personality)

    # Render the single-turn conversation through the model's chat template.
    chat_turns = [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(
        chat_turns,
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # Sampled generation; no gradients needed at inference time.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # The decoded text contains the prompt too; keep only the assistant turn.
    marker = "<|im_start|>assistant"
    if marker in decoded:
        reply = decoded.split(marker)[-1].strip()
        return reply.replace("<|im_end|>", "").strip()
    return decoded.split("assistant\n")[-1]
@app.get("/")
async def index():
    """Serve the static front-end page.

    Fix: the route decorator was missing (likely lost when the file was
    scraped — `app` is created but nothing registered this handler), so
    the page was unreachable.
    """
    return FileResponse("index.html")
@app.post("/chat")
async def chat(request: ChatRequest):
    """Generate a persona-styled reply to the user's message.

    Returns JSON with the reply text plus the validated personality key
    and its display name. Raises 400 on an empty message and 500 on
    unexpected generation failures.

    Fixes: (1) route decorator restored — the handler was never
    registered on `app`; (2) the 400 for an empty message previously sat
    inside the try block, so the broad `except Exception` converted it
    into a 500; (3) an unknown personality crashed the PERSONALITIES
    lookup with KeyError (also a 500) — we now validate once and use the
    validated key everywhere.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="Message is required")

    # Same fallback rule as generate_response, applied up front so the
    # metadata lookup below cannot raise KeyError.
    personality = request.personality if request.personality in PERSONALITIES else "dwight"

    try:
        response = generate_response(request.message, personality)
    except Exception as e:
        # Surface generation failures as a 500; chain the cause for logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "response": response,
        "personality": personality,
        "personality_name": PERSONALITIES[personality]["name"],
    }
# Launch a development server when run directly (port 7860 is the
# conventional Hugging Face Spaces port).
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=7860)