Suguru1846 commited on
Commit
d73377c
·
verified ·
1 Parent(s): 2d4e42f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from fastapi import FastAPI
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import traceback
6
+ import re
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+
9
+ # Set environment variables
10
+ os.environ["TRITON_DISABLE"] = "1"
11
+ os.environ["BNB_DISABLE_TRITON"] = "1"
12
+ os.environ["USE_TORCH"] = "1"
13
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
14
+
15
+ # Create writable temporary cache
16
+ os.makedirs("/tmp/hf_cache", exist_ok=True)
17
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
18
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
19
+ os.environ["TORCH_HOME"] = "/tmp/hf_cache"
20
+
21
+ # FastAPI app
22
+ app = FastAPI()
23
+
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"], # In production, replace with your app's domain
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ # Load your FULLY merged model (no adapter references)
33
+ model_name = "meta-llama/Llama-3.2-3B-Instruct # Your new merged model
34
+ print("Loading model and tokenizer...")
35
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ model_name,
38
+ torch_dtype=torch.float16, # Use fp16 for better performance
39
+ device_map="auto", # Automatically use available devices
40
+ low_cpu_mem_usage=True # Optimize memory usage
41
+ )
42
+ print("Model and tokenizer loaded successfully!")
43
+
44
+ @app.post("/generate")
45
+ async def generate_text(prompt: str, max_tokens: int = 50):
46
+ try:
47
+ # Format prompt for Llama models
48
+ formatted_prompt = f"<s>[INST] {prompt} [/INST]"
49
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
50
+ outputs = model.generate(
51
+ **inputs,
52
+ max_new_tokens=max_tokens,
53
+ do_sample=True,
54
+ temperature=0.7,
55
+ top_p=0.9
56
+ )
57
+
58
+ raw_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
59
+
60
+ # Clean up the response - remove the prompt and any remaining tags
61
+ clean_response = raw_response.replace(formatted_prompt, "").strip()
62
+ # Remove any remaining instruction tags
63
+ clean_response = re.sub(r'</?s>|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', clean_response).strip()
64
+
65
+ return {"response": clean_response}
66
+ except Exception as e:
67
+ error_msg = str(e)
68
+ error_trace = traceback.format_exc()
69
+ print(f"Error generating text: {error_msg}")
70
+ print(f"Traceback: {error_trace}")
71
+ return {"error": error_msg, "traceback": error_trace}
72
+
73
+ @app.get("/")
74
+ async def root():
75
+ return {"message": "Your Custom Counseling Model is Running"}