import os
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import re
# Disable Triton/bitsandbytes extras and silence the bitsandbytes welcome banner
os.environ["TRITON_DISABLE"] = "1"
os.environ["BNB_DISABLE_TRITON"] = "1"
os.environ["USE_TORCH"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# Create a writable temporary cache for model downloads
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["TORCH_HOME"] = "/tmp/hf_cache"

# FastAPI app
app = FastAPI()

# Load your FULLY merged model (no adapter references)
model_name = "Suguru1846/merged_counseling_model_full"  # Your new merged model
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/tmp/hf_cache",
)

@app.post("/generate")  # route path is an assumption; use whatever path your clients expect
async def generate_text(prompt: str, max_tokens: int = 50):
    try:
        # Format the prompt with Llama instruction tags
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # Decode only the newly generated tokens so the echoed prompt is dropped
        prompt_length = inputs["input_ids"].shape[1]
        raw_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Clean up the response - remove any echoed prompt text
        clean_response = raw_response.replace(formatted_prompt, "").strip()
        # Remove any remaining instruction tags (including malformed variants the model may emit)
        clean_response = re.sub(r'</?s>|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', clean_response).strip()
        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}

@app.get("/")  # simple landing/health-check route
async def root():
    return {"message": "Your Custom Counseling Model is Running"}