Suguru1846 committed on
Commit
9628182
·
verified ·
1 Parent(s): 3a5fc2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import torch
3
  from fastapi import FastAPI
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
 
6
  # Set environment variables
7
  os.environ["TRITON_DISABLE"] = "1"
@@ -18,8 +18,8 @@ os.environ["TORCH_HOME"] = "/tmp/hf_cache"
18
  # FastAPI app
19
  app = FastAPI()
20
 
21
- # Try loading a completely different model
22
- model_name = "facebook/opt-350m" # Much smaller, more compatible model
23
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_name,
@@ -31,7 +31,10 @@ model = AutoModelForCausalLM.from_pretrained(
31
  @app.post("/generate")
32
  async def generate_text(prompt: str, max_tokens: int = 50):
33
  try:
34
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
35
  outputs = model.generate(
36
  **inputs,
37
  max_new_tokens=max_tokens,
@@ -48,4 +51,4 @@ async def generate_text(prompt: str, max_tokens: int = 50):
48
 
49
  @app.get("/")
50
  async def root():
51
- return {"message": "Model is Running"}
 
1
  import os
2
  import torch
3
  from fastapi import FastAPI
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # Set environment variables
7
  os.environ["TRITON_DISABLE"] = "1"
 
18
  # FastAPI app
19
  app = FastAPI()
20
 
21
+ # Load your merged model
22
+ model_name = "Suguru1846/counseling_model_merged" # Your merged model
23
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_name,
 
31
  @app.post("/generate")
32
  async def generate_text(prompt: str, max_tokens: int = 50):
33
  try:
34
+ # Format prompt for Llama models
35
+ formatted_prompt = f"<s>[INST] {prompt} [/INST]"
36
+
37
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
38
  outputs = model.generate(
39
  **inputs,
40
  max_new_tokens=max_tokens,
 
51
 
52
  @app.get("/")
53
  async def root():
54
+ return {"message": "Your Custom Counseling Model is Running"}