# Jainish1808
# Uploaded 21-06 (9)
# 1d39b51
import logging
import os
import threading

import torch
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# Paths
# Hub id of the base checkpoint, the local directories holding each
# fine-tuned adapter, and the download cache used when loading weights.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_MODEL_DIR, QLORA_MODEL_DIR, ADALORA_MODEL_DIR = (
    "./lora_model",
    "./Qlora_model",
    "./adalora_model",
)
cache_dir = "./cache"
# Prompt Template
# Chat-style wrapper around the user's text. "{prompt}" is substituted with
# str.format() before tokenization, and the model's reply follows the
# trailing "<|assistant|>" line.
PROMPT_TEMPLATE = (
    "<|system|>\n"
    "You are Jack Patel. Answer questions about yourself using only "
    "information you were trained on. If you don't know something specific "
    "about yourself, say \"I don't have that information.\"\n"
    "If the user's question is not about Jack Patel, answer as an AI "
    "assistant using your general knowledge.\n"
    "Always respond in 2 to 3 short sentences.\n"
    "<|user|>\n"
    "{prompt}\n"
    "<|assistant|>\n"
)
# FastAPI application plus the Jinja2 loader for the HTML form. "templates"
# is a relative path, so it resolves against the process's working directory.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Global cache to avoid reloading models: adapter_path -> (tokenizer, model),
# populated lazily by load_model().
model_cache: dict = dict()
def load_model(adapter_path):
    """Return a ``(tokenizer, model)`` pair for the adapter at *adapter_path*.

    The base checkpoint ``BASE_MODEL`` is loaded and the PEFT adapter stored in
    *adapter_path* is applied on top. The result is moved to the GPU when one
    is available (CPU otherwise) and switched to eval mode.

    Pairs are memoized in ``model_cache`` keyed by *adapter_path*, so repeated
    calls for the same adapter reuse the already-loaded objects. Both the
    tokenizer and the base weights are downloaded into ``cache_dir``.
    """
    cached = model_cache.get(adapter_path)
    if cached is not None:
        return cached

    print(f"🔄 Loading model from: {adapter_path}")
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"

    # Use the same cache_dir as the base weights so every download lands in
    # one (writable) location instead of the default Hugging Face cache.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, cache_dir=cache_dir)
    # Use EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        # Half precision only when running on a GPU.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        cache_dir=cache_dir,
    )
    model = PeftModel.from_pretrained(base, adapter_path)
    model.to(device).eval()

    model_cache[adapter_path] = (tokenizer, model)
    return tokenizer, model
def _cut_at_new_turn(text):
    """Return *text* truncated before any new chat turn, whitespace-stripped.

    The prompt format marks turns with ``<|user|>`` / ``<|system|>``. If the
    model keeps going past its own answer and opens another turn, everything
    from the first such marker onward is dropped.
    """
    for marker in ("<|user|>", "<|system|>"):
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_response(
    prompt,
    tokenizer,
    model,
    *,
    max_new_tokens=50,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
):
    """Generate the assistant's reply to *prompt* with a loaded model.

    The user text is wrapped in ``PROMPT_TEMPLATE`` and sampled from *model*.
    Only the newly generated tokens are decoded, so neither the system prompt
    nor the user's own text can leak into the returned reply.

    Args:
        prompt: Raw user question.
        tokenizer: Tokenizer matching *model*.
        model: Causal LM (e.g. a ``PeftModel``) already on its target device.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The reply text, cut before any follow-up chat turn and stripped.
    """
    full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=repetition_penalty,
        )
    # generate() echoes the prompt tokens before the continuation. Slicing
    # them off is more reliable than searching the decoded text for
    # "<|assistant|>", which may be absent after decoding or be re-emitted by
    # the model.
    reply_ids = output[0][prompt_len:]
    reply = tokenizer.decode(reply_ids, skip_special_tokens=True)
    return _cut_at_new_turn(reply)
@app.get("/", response_class=HTMLResponse)
async def form_get(request: Request):
    """Serve the blank form: no prompt, no model chosen, no result yet."""
    blank_context = dict(
        request=request,
        result=None,
        model="",
        prompt="",
        data_count=0,
    )
    return templates.TemplateResponse("index.html", blank_context)
# Guards load_model() + generate_response(). Inference runs on a worker
# thread (see form_post). Without this lock, two concurrent first requests
# could load the same adapter twice, and several requests could call
# generate() on one shared model at once. Holding it keeps inference
# one-request-at-a-time while the event loop stays responsive.
_INFERENCE_LOCK = threading.Lock()


def _run_inference(adapter_path, prompt):
    """Load (or reuse) the adapter at *adapter_path* and answer *prompt*.

    Blocking; meant to be executed off the event loop.
    """
    with _INFERENCE_LOCK:
        tokenizer, model = load_model(adapter_path)
        return generate_response(prompt, tokenizer, model)


def _render_form(request, result, model_label, prompt):
    """Render index.html with the context keys this app passes to the form."""
    return templates.TemplateResponse("index.html", {
        "request": request,
        "result": result,
        "model": model_label,
        "prompt": prompt,
        "data_count": 0,  # Replace with real data count if available
    })


@app.post("/", response_class=HTMLResponse)
async def form_post(
    request: Request,
    prompt: str = Form(...),
    model_type: str = Form(...),
):
    """Answer the submitted *prompt* with the adapter selected by *model_type*.

    Unknown ``model_type`` values, or adapters whose directory does not
    exist, re-render the form with an error message. Exceptions raised while
    loading or generating are logged with a traceback and shown to the user
    as an error string.
    """
    model_paths = {
        "lora": LORA_MODEL_DIR,
        "Qlora1": QLORA_MODEL_DIR,
        "adalora": ADALORA_MODEL_DIR,
    }
    model_labels = {
        "lora": "LoRA - lora-tinyllama-final",
        "Qlora1": "QLoRA - lora-tinyllama-final1",
        "adalora": "AdaLoRA - adalora-tinyllama-final",
    }
    adapter_path = model_paths.get(model_type)
    model_label = model_labels.get(model_type, model_type.upper())

    if not adapter_path or not os.path.exists(adapter_path):
        return _render_form(
            request, "Invalid or missing model selected.", model_label, prompt
        )

    try:
        # Loading weights and sampling are blocking CPU/GPU work. Awaiting
        # them in the threadpool keeps the event loop free to serve other
        # requests (e.g. GET /) in the meantime.
        result = await run_in_threadpool(_run_inference, adapter_path, prompt)
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Generation failed for adapter %s", adapter_path
        )
        result = f"Error generating response: {str(e)}"

    return _render_form(request, result, model_label, prompt)
if __name__ == "__main__":
    # Running the file directly starts a uvicorn server on all IPv4
    # interfaces.
    import uvicorn

    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)