import os
import re
import traceback

import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set environment variables (disable Triton paths and silence the bitsandbytes banner)
os.environ["TRITON_DISABLE"] = "1"
os.environ["BNB_DISABLE_TRITON"] = "1"
os.environ["USE_TORCH"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# Create a writable temporary cache and point the Hugging Face / Torch caches at it
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["TORCH_HOME"] = "/tmp/hf_cache"

# FastAPI app
app = FastAPI()

# Load your FULLY merged model (no adapter references)
model_name = "Suguru1846/merged_counseling_model_full"  # Your new merged model
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/tmp/hf_cache",
)


@app.post("/generate")
async def generate_text(prompt: str, max_tokens: int = 50):
    # `prompt` and `max_tokens` are plain parameters, so FastAPI reads them
    # from the query string rather than from a JSON body.
    try:
        # Format the prompt with Llama-style instruction tags
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up the response: remove the echoed prompt and any remaining
        # instruction tags or special-token remnants
        clean_response = raw_response.replace(formatted_prompt, "").strip()
        clean_response = re.sub(r"</?s>|\[/?INST\]", "", clean_response).strip()

        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}


@app.get("/")
async def root():
    return {"message": "Your Custom Counseling Model is Running"}
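
# --- Usage sketch (assumption: this file is saved as app.py; adjust the module
# name and host/port to your deployment) ---
# Start the server with uvicorn:
#   uvicorn app:app --host 0.0.0.0 --port 8000
# Because `prompt` and `max_tokens` are declared as plain function parameters,
# FastAPI expects them as query parameters on the POST request, e.g.:
#   curl -X POST "http://localhost:8000/generate?prompt=I%20feel%20anxious&max_tokens=80"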