Spaces:
Sleeping
Sleeping
File size: 2,190 Bytes
e261c43 1343cce 9628182 71ee5c2 3ad6551 1343cce 3f8d4b3 c8253e3 3f8d4b3 c8253e3 f65fd4f 3f8d4b3 c8253e3 3f8d4b3 5920a0a ea24031 f65fd4f b069624 3f8d4b3 1f4f76d 5920a0a ac6f59d ea24031 9628182 1343cce 3fcfaa2 1343cce f65fd4f 3fcfaa2 ac6f59d 71ee5c2 5920a0a ea24031 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | import os
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import re
# Environment flags set before any model is loaded. Judging by their names
# they opt out of Triton kernels (globally and in bitsandbytes), select the
# PyTorch backend, and silence the bitsandbytes welcome banner.
os.environ["TRITON_DISABLE"] = "1"
os.environ["BNB_DISABLE_TRITON"] = "1"
os.environ["USE_TORCH"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# One writable temporary cache directory shared by all downloads; keeping the
# path in a single constant avoids five copies that must stay in sync.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# NOTE(review): `transformers` is imported at the top of the file, before these
# assignments run, and Hugging Face libraries generally resolve HF_HOME /
# TRANSFORMERS_CACHE at import time. These variables may therefore not change
# the default cache; the explicit cache_dir=CACHE_DIR arguments below are what
# reliably route downloads here. TRANSFORMERS_CACHE is also deprecated in
# recent transformers releases in favour of HF_HOME.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["TORCH_HOME"] = CACHE_DIR

# FastAPI application instance; route handlers register on it via decorators.
app = FastAPI()

# Load the FULLY merged model (no adapter references). Loading happens once at
# import time and the resulting objects are shared by every request.
model_name = "Suguru1846/merged_counseling_model_full"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
# Weights are loaded in float16; device_map="auto" lets transformers place them
# on whatever devices are available.
# NOTE(review): on a CPU-only host this still loads float16 weights, and
# half-precision CPU kernels are slow or missing on some torch versions.
# Confirm the target hardware before relying on this dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=CACHE_DIR,
)
@app.post("/generate")
async def generate_text(prompt: str, max_tokens: int = 50):
    """Generate a sampled reply to ``prompt`` with the loaded model.

    Args:
        prompt: User text, wrapped in a ``[INST] ... [/INST]`` instruction
            template before tokenization.
        max_tokens: Upper bound on newly generated tokens (passed to
            ``model.generate`` as ``max_new_tokens``).

    Returns:
        ``{"response": text}`` containing only the newly generated text on
        success, or ``{"error": message, "traceback": formatted_traceback}``
        if any exception is raised.
    """
    # NOTE(review): model.generate() is synchronous and runs directly on the
    # event loop here. No other request (including GET /) is served while a
    # generation is in progress. Declaring this handler with plain `def` would
    # let FastAPI run it in its threadpool, at the cost of allowing concurrent
    # generations and therefore higher peak memory. Decide deliberately.
    try:
        # Instruction template for Llama-style models. The template already
        # starts with the BOS marker "<s>", so the tokenizer is told not to add
        # its own special tokens; otherwise a Llama tokenizer prepends a second
        # BOS. Assumes "<s>" is this tokenizer's BOS token (TODO confirm).
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # For a causal LM, generate() returns the prompt tokens followed by the
        # continuation. Decode only the continuation. The old approach (decode
        # everything, then str.replace the prompt) does not work reliably:
        # skip_special_tokens drops "<s>" and detokenization may alter spacing,
        # so the replace usually failed and the user's prompt leaked into the
        # response.
        prompt_length = inputs["input_ids"].shape[-1]
        raw_response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Strip instruction-tag variants (including misspelled ones) that the
        # model may emit on its own.
        clean_response = re.sub(
            r'</?s>|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', raw_response
        ).strip()
        return {"response": clean_response}
    except Exception as e:
        # Top-level request boundary: report the failure in the response body
        # instead of letting the exception propagate out of the handler.
        # NOTE(review): this exposes the full server traceback to the client;
        # consider logging it server-side only.
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}
@app.get("/")
async def root():
    """Return a fixed status message showing the app is serving requests."""
    status = {"message": "Your Custom Counseling Model is Running"}
    return status