File size: 2,190 Bytes
e261c43
1343cce
 
9628182
71ee5c2
3ad6551
1343cce
3f8d4b3
c8253e3
3f8d4b3
c8253e3
 
f65fd4f
 
 
3f8d4b3
 
 
c8253e3
3f8d4b3
5920a0a
 
ea24031
 
f65fd4f
 
 
b069624
3f8d4b3
 
 
1f4f76d
5920a0a
 
ac6f59d
ea24031
 
9628182
1343cce
3fcfaa2
1343cce
 
 
 
 
f65fd4f
3fcfaa2
 
 
 
 
 
 
 
ac6f59d
71ee5c2
 
 
 
 
5920a0a
 
 
ea24031
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import re

# Writable location used for every Hugging Face / torch download cache below.
CACHE_DIR = "/tmp/hf_cache"

# Backend switches: disable Triton (directly and for bitsandbytes), force the
# PyTorch backend, and silence the bitsandbytes welcome banner.
# NOTE(review): torch and transformers are imported above this point, so any
# of these variables that those libraries read at import time may take effect
# too late — confirm whether they need to move before the imports.
os.environ.update(
    {
        "TRITON_DISABLE": "1",
        "BNB_DISABLE_TRITON": "1",
        "USE_TORCH": "1",
        "BITSANDBYTES_NOWELCOME": "1",
    }
)

# Create the cache directory and point all cache-related variables at it.
# The explicit cache_dir arguments further down are what guarantee downloads
# land here, regardless of when the variables are read.
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    dict.fromkeys(("HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME"), CACHE_DIR)
)

# ASGI application; the routes are registered on it below.
app = FastAPI()

# Fully merged fine-tuned checkpoint (no adapter references are needed at load time).
model_name = "Suguru1846/merged_counseling_model_full"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=CACHE_DIR,
    device_map="auto",          # let accelerate place weights on available devices
    torch_dtype=torch.float16,  # half precision halves the weight memory
)

@app.post("/generate")
async def generate_text(prompt: str, max_tokens: int = 50):
    """Generate a reply to *prompt* with the merged counseling model.

    Args:
        prompt: User text. It is wrapped in a Llama-style
            ``<s>[INST] ... [/INST]`` instruction template before tokenization.
        max_tokens: Maximum number of newly generated tokens (forwarded to
            ``generate`` as ``max_new_tokens``).

    Returns:
        ``{"response": <text>}`` on success, where ``<text>`` contains only the
        model's continuation (not the echoed prompt), with instruction tags
        removed. If any exception occurs, it is printed to stdout and the dict
        ``{"error": <message>, "traceback": <formatted stack trace>}`` is
        returned instead.
    """
    import asyncio  # local import: only this endpoint needs it

    try:
        # Format prompt for Llama models. The template already spells out the
        # BOS marker "<s>", so tokenize without special tokens; otherwise the
        # tokenizer prepends a second BOS token.
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
        inputs = tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # model.generate is a long, blocking call. Run it in a worker thread so
        # the event loop (and other routes such as "/") keeps serving requests.
        outputs = await asyncio.to_thread(
            model.generate,
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # For a causal LM, generate() returns the prompt tokens followed by the
        # continuation. Decode only the continuation. Stripping the prompt text
        # after decoding is unreliable: skip_special_tokens drops "<s>", so the
        # decoded prefix no longer equals formatted_prompt.
        new_tokens = outputs[0][prompt_length:]
        raw_response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Remove any instruction tags (including malformed variants) the model
        # may still emit inside its answer.
        clean_response = re.sub(
            r'</?s>|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', raw_response
        ).strip()

        return {"response": clean_response}
    except Exception as e:
        # NOTE(review): returning the traceback exposes server internals to
        # clients. It is kept so existing clients relying on this response
        # shape keep working.
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}

@app.get("/")
async def root():
    """Return a fixed JSON message confirming that the service is up."""
    status_message = "Your Custom Counseling Model is Running"
    return {"message": status_message}