# TalkToMe / app.py
import os
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import re
# Set environment variables
os.environ["TRITON_DISABLE"] = "1"
os.environ["BNB_DISABLE_TRITON"] = "1"
os.environ["USE_TORCH"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
# Create writable temporary cache
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["TORCH_HOME"] = "/tmp/hf_cache"
# FastAPI app
app = FastAPI()
# Load your FULLY merged model (no adapter references)
model_name = "Suguru1846/merged_counseling_model_full" # Your new merged model
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/tmp/hf_cache"
)
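# Note: device_map="auto" requires the `accelerate` package; it places the
# float16 weights on a GPU when one is available and falls back to CPU otherwise.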
@app.post("/generate")
async def generate_text(prompt: str, max_tokens: int = 50):
    try:
        # Format prompt for Llama models
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Clean up the response - remove the echoed prompt; skip_special_tokens
        # already drops "<s>", so also strip the tag-only form of the prompt
        clean_response = raw_response.replace(formatted_prompt, "")
        clean_response = clean_response.replace(f"[INST] {prompt} [/INST]", "").strip()
        # Remove any remaining instruction tags (including malformed variants)
        clean_response = re.sub(r'</?s>|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', clean_response).strip()
        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}
@app.get("/")
async def root():
return {"message": "Your Custom Counseling Model is Running"}