import os

import torch
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Paths
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_MODEL_DIR = "./lora_model"
QLORA_MODEL_DIR = "./Qlora_model"
ADALORA_MODEL_DIR = "./adalora_model"
cache_dir = "./cache"

# Prompt Template
PROMPT_TEMPLATE = """<|system|>
You are Jack Patel. Answer questions about yourself using only information you were trained on. If you don't know something specific about yourself, say "I don't have that information."
If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge.
Always respond in 2 to 3 short sentences.
<|user|>
{prompt}
<|assistant|>
"""

app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Global cache to avoid reloading models
model_cache = {}


def load_model(adapter_path):
    # Return a cached (tokenizer, model) pair if this adapter was already loaded
    if adapter_path in model_cache:
        return model_cache[adapter_path]

    print(f"🔄 Loading model from: {adapter_path}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the base model, then attach the PEFT adapter weights on top of it
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        cache_dir=cache_dir,
    )
    model = PeftModel.from_pretrained(base, adapter_path)
    model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    model_cache[adapter_path] = (tokenizer, model)
    return tokenizer, model


def generate_response(prompt, tokenizer, model):
    full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded.split("<|assistant|>")[-1].strip() if "<|assistant|>" in decoded else decoded.strip()


# Route decorators assumed (a single "/" form endpoint) so the GET/POST handlers below are actually routable.
@app.get("/", response_class=HTMLResponse)
async def form_get(request: Request):
    # Render the empty form
    return templates.TemplateResponse("index.html", {
        "request": request,
        "result": None,
        "model": "",
        "prompt": "",
        "data_count": 0
    })


@app.post("/", response_class=HTMLResponse)  # POST route assumed to match the GET form above
async def form_post(
    request: Request,
    prompt: str = Form(...),
    model_type: str = Form(...)
):
    # Map the submitted model_type to its adapter directory and display label
    model_paths = {
        "lora": LORA_MODEL_DIR,
        "Qlora1": QLORA_MODEL_DIR,
        "adalora": ADALORA_MODEL_DIR
    }
    model_labels = {
        "lora": "LoRA - lora-tinyllama-final",
        "Qlora1": "QLoRA - lora-tinyllama-final1",
        "adalora": "AdaLoRA - adalora-tinyllama-final"
    }

    adapter_path = model_paths.get(model_type)
    model_label = model_labels.get(model_type, model_type.upper())

    if not adapter_path or not os.path.exists(adapter_path):
        return templates.TemplateResponse("index.html", {
            "request": request,
            "result": "Invalid or missing model selected.",
            "model": model_label,
            "prompt": prompt,
            "data_count": 0
        })

    try:
        tokenizer, model = load_model(adapter_path)
        result = generate_response(prompt, tokenizer, model)
    except Exception as e:
        result = f"Error generating response: {str(e)}"

    return templates.TemplateResponse("index.html", {
        "request": request,
        "result": result,
        "model": model_label,
        "prompt": prompt,
        "data_count": 0  # Replace with real data count if available
    })


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
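
# Usage sketch (assumptions: the file is saved as app.py and the "/" routes added above are used;
# the port matches the __main__ block). Start the server, then submit the same form fields with curl:
#
#   python app.py
#   curl -X POST http://localhost:7860/ -F "prompt=Who is Jack Patel?" -F "model_type=lora"
#
# The response is the rendered index.html template with the generated answer in "result".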